Skip to content

[Midend] Add INT8 support to AMX matmul pass with unified dtype detection#633

Open
Jiajun-Ji wants to merge 13 commits intobuddy-compiler:mainfrom
Jiajun-Ji:AMX-INT8
Open

[Midend] Add INT8 support to AMX matmul pass with unified dtype detection#633
Jiajun-Ji wants to merge 13 commits intobuddy-compiler:mainfrom
Jiajun-Ji:AMX-INT8

Conversation

@Jiajun-Ji
Copy link
Contributor

This PR extends the AMX matmul optimization pass to support INT8 data types alongside BF16.

The unified -matmul-amx pass now automatically detects input data types (bf16×bf16→f32 or i8×i8→i32) and applies the appropriate AMX tile operations with VNNI packing.

Properly pack B matrix into VNNI format with interleaved rows for correct AMX BF16 tile operations.
Enable hardware-accelerated matrix multiplication for DeepSeek-R1 using Intel AMX instructions with proper VNNI layout conversion.
This directory was added accidentally and should not be part of the project.
…nalysis in AMX matmul test for precision loss comparison
…nalysis in AMX matmul test for precision loss comparison
Copilot AI review requested due to automatic review settings December 23, 2025 07:10
Copy link

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR extends the AMX matmul optimization pass to support INT8 data types alongside BF16, creating a unified -matmul-amx pass that automatically detects input data types and applies appropriate AMX tile operations with VNNI packing.

Key changes:

  • Added automatic dtype detection (bf16×bf16→f32 or i8×i8→i32) in the AMX matmul pass
  • Implemented INT8-specific VNNI packing and dimension validation for AMX operations
  • Extended DeepSeekR1 example with BF16 support including prefill/decode phases

Reviewed changes

Copilot reviewed 14 out of 14 changed files in this pull request and generated 8 comments.

Show a summary per file
File Description
riscv-ime-extension-spec Added subproject commit reference
midend/lib/Conversion/MatMulOptimization/MatmulAMX.cpp Unified AMX pass with BF16/INT8 dtype detection, separate VNNI packing functions, and dimension validators
frontend/Python/ops/tosa.py Added bf16 support in addmm_op and fixed attribute access in scaled_dot_product
examples/BuddyMatmul/test-amx-int8-pass.mlir New test file demonstrating INT8 AMX conversion via unified pass
examples/BuddyMatmul/test-amx-bf16-pass.mlir New test file demonstrating BF16 AMX conversion via unified pass
examples/BuddyMatmul/makefile Updated build targets for INT8/BF16 testing and added missing linker flag
examples/BuddyMatmul/linalg-to-amx-example.mlir Removed deprecated example file
examples/BuddyMatmul/linalg-bf16-matmul.mlir Removed deprecated test file
examples/BuddyMatmul/amx-wrapper.c Removed redundant AMX_DATA permission request
examples/BuddyMatmul/amx-int8-matmul.mlir Added comprehensive INT8 AMX matmul example with VNNI packing
examples/BuddyMatmul/amx-bf16-matmul.mlir Enhanced BF16 example with cyclic initialization and VNNI packing
examples/BuddyDeepSeekR1/import-deepseek-r1.py Extended import script for BF16 with prefill/decode phases
examples/BuddyDeepSeekR1/buddy-deepseek-r1-bf16-main.cpp Implemented BF16 inference with AMX support, KV cache management, and prefill/decode separation
examples/BuddyDeepSeekR1/CMakeLists.txt Updated build pipeline for BF16 with AMX flags and separate prefill/decode compilation

💡 Add Copilot custom instructions for smarter, more guided reviews. Learn how to get started.


matmul_op = tosa.ClampOp(
matmul_op.result.type,
matmul_op.type,
Copy link

Copilot AI Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed from matmul_op.result.type to matmul_op.type. Verify that matmul_op is a Value with a .type attribute, not an Operation with a .result.type attribute. If matmul_op is an operation, this will fail at runtime.

Copilot uses AI. Check for mistakes.
constexpr size_t ParamsSize = 1777088064;
constexpr size_t MaxVocabSize = 151936;
constexpr size_t MaxTokenLength = 40;
constexpr size_t MaxTokenLength = 32;
Copy link

Copilot AI Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

MaxTokenLength reduced from 40 to 32 without corresponding adjustment in comments or documentation. Ensure this aligns with the model's cache configuration (lines 121, 124 reference max_cache_len=32) and update any references to the old value of 40.

Copilot uses AI. Check for mistakes.
Comment on lines +344 to +345
const std::string vocabDir = deepSeekR1Dir + "vocab.txt";
const std::string paramsDir = deepSeekR1BuildDir + "arg0-bf16.data";
Copy link

Copilot AI Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removed the forward slash separator between directory and filename. This will cause file path concatenation errors if deepSeekR1Dir or deepSeekR1BuildDir don't end with a trailing slash. The compile definitions at lines 685-686 now add trailing slashes, but this creates a fragile dependency.

Suggested change
const std::string vocabDir = deepSeekR1Dir + "vocab.txt";
const std::string paramsDir = deepSeekR1BuildDir + "arg0-bf16.data";
const std::string vocabDir =
(std::filesystem::path(deepSeekR1Dir) / "vocab.txt").string();
const std::string paramsDir =
(std::filesystem::path(deepSeekR1BuildDir) / "arg0-bf16.data").string();

Copilot uses AI. Check for mistakes.
Comment on lines 167 to 173
SmallVector<Value> dynamicSizes;
SmallVector<int64_t> staticShape;

Value c0 = builder.create<arith::ConstantIndexOp>(loc, 0);
Value c1 = builder.create<arith::ConstantIndexOp>(loc, 1);
Value c2 = builder.create<arith::ConstantIndexOp>(loc, 2);

Copy link

Copilot AI Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The constant c2 is created but c1 is already defined at line 171. Consider defining c2 alongside other constants at the beginning of the function for consistency with the existing pattern of defining c0, c1 together.

Suggested change
SmallVector<Value> dynamicSizes;
SmallVector<int64_t> staticShape;
Value c0 = builder.create<arith::ConstantIndexOp>(loc, 0);
Value c1 = builder.create<arith::ConstantIndexOp>(loc, 1);
Value c2 = builder.create<arith::ConstantIndexOp>(loc, 2);
Value c0 = builder.create<arith::ConstantIndexOp>(loc, 0);
Value c1 = builder.create<arith::ConstantIndexOp>(loc, 1);
Value c2 = builder.create<arith::ConstantIndexOp>(loc, 2);
SmallVector<Value> dynamicSizes;
SmallVector<int64_t> staticShape;

Copilot uses AI. Check for mistakes.
Comment on lines +121 to +124
config=model.config, max_cache_len=32
)
past_key_values_decode = StaticCache(
config=model.config, max_cache_len=20
config=model.config, max_cache_len=32
Copy link

Copilot AI Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The max_cache_len is hardcoded to 32 (changed from 20). Consider extracting this as a constant or configuration parameter to avoid inconsistencies with MaxTokenLength in the C++ code.

Copilot uses AI. Check for mistakes.
Comment on lines +260 to +261
@${BUDDY_OPT} ./test-amx-bf16-pass.mlir \
-matmul-amx \
Copy link

Copilot AI Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Target name 'linalg-bf16-matmul-lower' now operates on 'test-amx-bf16-pass.mlir' instead of 'linalg-bf16-matmul.mlir', which may confuse users expecting the target name to match the input file pattern.

Copilot uses AI. Check for mistakes.
maxIdx = i;
/// Find the index of the max value in BF16 array
int findMaxIndex(const uint16_t *start, const uint16_t *end) {
size_t size = end - start;
Copy link

Copilot AI Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Changed function signature from findMaxIndex(const uint16_t *start, size_t length) to use pointer range (start, end). While this is a valid design, the function now computes size from pointer subtraction but doesn't validate that end >= start, which could cause issues if called incorrectly.

Suggested change
size_t size = end - start;
// Validate pointer range to avoid undefined behavior from invalid ranges.
// Require a non-null, non-empty range with end strictly after start.
if (start == nullptr || end == nullptr || end <= start) {
return -1;
}
size_t size = static_cast<size_t>(end - start);

Copilot uses AI. Check for mistakes.
Comment on lines +430 to +433
Value c2Val =
builder.create<arith::ConstantIndexOp>(loc, 2);
Value kDiv2 = builder.create<arith::DivUIOp>(loc, k, c2Val);
Value nMul2 = builder.create<arith::MulIOp>(loc, n, c2Val);
Copy link

Copilot AI Dec 23, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Creates a new constant c2Val inside the innermost loop. This constant creation should be hoisted outside the loop structure to avoid redundant operations in each iteration.

Copilot uses AI. Check for mistakes.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants