Conversation
Pull request overview
This PR introduces a work-stealing-based persistent GEMM kernel that dynamically allocates tile IDs across compute units instead of using fixed partitioning. The implementation uses per-XCD (chiplet) atomic counters to reduce contention compared to global atomic operations. The work-stealing kernel is exposed as an opt-in feature through a new work_stealing parameter in the matmul APIs.
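The per-XCD counter scheme described above can be modeled in plain Python. This is a hypothetical sketch, not the actual Triton kernel: each XCD (chiplet) gets its own tile counter and owns an interleaved slice of the tile space, so workgroups increment a local counter instead of contending on one global atomic. The stealing path for an exhausted XCD is omitted for brevity, and all names are illustrative.

```python
import itertools

# Illustrative constants, not taken from the PR.
NUM_XCDS = 8
TOTAL_TILES = 100

# One independent counter per XCD instead of a single global counter,
# modeling the reduced-contention per-XCD atomics described in the PR.
counters = [itertools.count() for _ in range(NUM_XCDS)]

def next_tile(xcd_id):
    """Return the next tile id for a workgroup on this XCD, or None when its slice is exhausted."""
    # Tiles are interleaved across XCDs: XCD k owns tiles k, k + NUM_XCDS, k + 2*NUM_XCDS, ...
    local = next(counters[xcd_id])
    tile = xcd_id + local * NUM_XCDS
    return tile if tile < TOTAL_TILES else None

# Drain every XCD's slice and check that each tile id is claimed exactly once.
claimed = []
for xcd in range(NUM_XCDS):
    while (t := next_tile(xcd)) is not None:
        claimed.append(t)

assert sorted(claimed) == list(range(TOTAL_TILES))
```

Because assignment is dynamic (each workgroup grabs the next id when it finishes a tile), a slow XCD simply claims fewer tiles rather than stalling the whole grid, which is the motivation for this design over fixed partitioning.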
Changes:
- Added `MatmulConfig` class to pre-allocate and manage GPU buffers for kernel launches (tile counters, stream-K locks/partials)
- Implemented work-stealing kernel with per-XCD atomic tile counters in `persistent_gemm_work_stealing.py`
- Extended all matmul APIs with optional `work_stealing` and `config` parameters to support the new kernel
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 22 comments.
| File | Description |
|---|---|
| include/tritonblas/matmul.py | Added `MatmulConfig` class for buffer management; integrated `work_stealing` parameter and `ws_persistent_matmul` kernel; refactored buffer allocation to use config objects |
| include/tritonblas/kernels/persistent_gemm_work_stealing.py | New work-stealing kernel implementation with per-XCD atomic counters and dynamic tile assignment |
| include/tritonblas/kernels/__init__.py | Exported `ws_persistent_matmul` kernel |
| include/tritonblas/__init__.py | Exported `MatmulConfig` and `matmul_preamble` to the public API |
| tests/test_work_stealing.py | Standalone test with custom module loading to verify work-stealing kernel correctness and performance |
| benchmarks/benchmark_work_stealing.py | Comprehensive benchmark comparing work-stealing against static persistent, stream-K, and torch.matmul |
ryanswann-amd
left a comment
I think we need to think more about how we intend people to use matmul.
include/tritonblas/matmul.py
```python
b: torch.Tensor,
c: torch.Tensor,
selector,
config: MatmulConfig,
```
I'm pretty confident this makes the matmul call non-torch-like. Do we want the user to manage the locks tensor (by passing it in via the MatmulConfig object), or do we want tritonBLAS to manage it internally (like hipBLASLt)?
@asunderwood is a good example of a torch user. Thoughts?
This particular change isn't an issue because, from a torch API perspective, the user won't call persistent_matmul_lt() themselves. They'll instead call matmul(), which does gain a new, non-torch kwarg, but it's not the first we've added, and as long as it has a default value that lets a user skip it, it retains torch compatibility.
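The compatibility point above can be shown with a minimal, hypothetical example: adding a keyword-only argument with a default keeps every existing torch-style call site working unchanged. The names below are stand-ins, not the real tritonblas signatures.

```python
# Hypothetical stand-in for tritonblas.matmul; the real function dispatches
# to a Triton kernel. Here we just record which path was selected.
def matmul(a, b, *, work_stealing=False, config=None):
    mode = "work_stealing" if work_stealing else "static"
    return a * b, mode  # placeholder for the real kernel dispatch

# Old call sites are unaffected by the new kwarg.
assert matmul(2, 3) == (6, "static")
# Opting in to the new kernel is explicit.
assert matmul(2, 3, work_stealing=True) == (6, "work_stealing")
```

Making the parameters keyword-only also means a future signature change can't silently break positional callers.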
include/tritonblas/matmul.py
This results in an allocation on every call when not using config, right? If I just call C = tritonblas.matmul(A, B)
include/tritonblas/matmul.py
```diff
 else:
-    locks = torch.empty(grids, device="cuda", dtype=torch.uint8)
     P = torch.empty(grids, block_size, device="cuda", dtype=torch.float32)
+    locks = torch.empty(grids, device=cfg.device, dtype=torch.uint8)
```
This results in an allocation on every call to C = tritonblas.matmul(A, B) if I don't pre-allocate buffers. That means to get performance (and avoid an allocation each time) users have to pass in the pre-allocated buffers, which is what we were trying to avoid in the previous approach.
Aren't there already locks in the config being passed in? Shouldn't we just use those instead of re-allocating?
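One common way to resolve the allocation-per-call concern raised above is to cache scratch buffers inside the library, keyed by launch shape, so repeated calls reuse allocations without the user ever passing a config object. This is an illustrative sketch under that assumption, not the actual tritonblas code; the buffer types are plain-Python stand-ins for the torch.empty GPU tensors in the diff.

```python
# Hypothetical internal cache: (grids, block_size) -> (locks, partials).
_scratch_cache = {}

def get_scratch(grids, block_size):
    """Return cached (locks, partials) buffers for this launch shape,
    allocating only on the first call for a given shape."""
    key = (grids, block_size)
    if key not in _scratch_cache:
        # Stand-ins for torch.empty(...) allocations on the GPU.
        locks = bytearray(grids)
        partials = [[0.0] * block_size for _ in range(grids)]
        _scratch_cache[key] = (locks, partials)
    return _scratch_cache[key]

first = get_scratch(304, 256)
second = get_scratch(304, 256)
assert first is second            # same shape: buffers reused, no new allocation
assert get_scratch(608, 256) is not first  # a new shape allocates once
```

This keeps the torch-like call C = tritonblas.matmul(A, B) allocation-free after warm-up, at the cost of the library holding the buffers' lifetime, the trade-off the reviewers are weighing against an explicit MatmulConfig.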
Resolve merge conflicts from 'Support async copy (#72)':
- origami.py: Keep both work-stealing params (total_cus, active_cus) and main's num_stages param. Keep work-stealing's multi-version origami API handling for select_workgroup_mapping over main's simpler version.
- matmul.py: Use main's getattr(selector, "num_stages", 2) for proper num_stages propagation. Keep work-stealing's pre-allocated buffer optimization for locks/P but use a.device (from main) in the fallback path.
ryanswann-amd
left a comment
Looks good other than the weird diff file.
arch.diff
```diff
diff --git a/shared/origami/python/origami_module.cpp b/shared/origami/python/origami_module.cpp
index a85c5da..c4b9b07 100644
--- a/shared/origami/python/origami_module.cpp
+++ b/shared/origami/python/origami_module.cpp
@@ -154,6 +154,7 @@ NB_MODULE(origami, m) {
                 size_t,
                 std::tuple<double, double, double>>())
       .def("print", &hardware_t::print)
+      .def_rw("arch", &hardware_t::arch)
       .def_rw("N_CU", &hardware_t::N_CU)
       .def_rw("lds_capacity", &hardware_t::lds_capacity)
       .def_rw("mem1_perf_ratio", &hardware_t::mem1_perf_ratio)
```
We should not commit diff files.
Point setup.py to ryaswann/tritonblas_expose_arch branch of
rocm-libraries which includes the .def_rw("arch", &hardware_t::arch)
change. Remove arch.diff and the git apply step since the change is
now baked into the upstream commit.
Motivation
Dynamically hand out tile IDs instead of using fixed partitioning.
Getting Started