
Work-Stealing-based Persistent Kernel #64

Open

neoblizz wants to merge 22 commits into main from neoblizz/work-stealing

Conversation

@neoblizz (Member) commented Feb 5, 2026

Motivation

Let workgroups dynamically steal tile IDs at runtime instead of relying on a fixed, static partitioning.

Getting Started

git clone -b neoblizz/work-stealing https://github.com/ROCm/tritonBLAS
cd tritonBLAS
pip install -e .

# Install latest triton
git clone https://github.com/triton-lang/triton
cd triton
pip install -e .

# Work-stealing CU sweep (304 to 32 CUs)
python benchmarks/tritonblas_matmul.py \
    --input-yaml datasets/bench_8k.yaml \
    --work-stealing \
    --cu-sweep \
    --cu-sweep-max-remove 34 \
    --counters-per-xcd 1 \
    --output-csv results_ws_cu_sweep.csv

python benchmarks/torch_matmul.py \
    --input-yaml datasets/bench_8k.yaml \
    --cu-sweep \
    --cu-sweep-max-remove 34 \
    --output-csv results_torch_cu_sweep.csv

python tools/plot_cu_sweep.py \
    --persistent results_persistent_sweep.csv \
    --torch      results_torch_cu_sweep.csv \
    --ws-cpc 1   results_ws_cu_sweep.csv \
    -o cu_sweep_plot.png

Copilot AI review requested due to automatic review settings (February 5, 2026 20:34)

Copilot AI (Contributor) left a comment

Pull request overview

This PR introduces a work-stealing-based persistent GEMM kernel that dynamically allocates tile IDs across compute units instead of using fixed partitioning. The implementation uses per-XCD (chiplet) atomic counters to reduce contention compared to global atomic operations. The work-stealing kernel is exposed as an opt-in feature through a new work_stealing parameter in the matmul APIs.
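
For intuition, here is a minimal sketch of the per-XCD tile-acquisition loop described above, written in Triton. The buffer name tile_counters, the NUM_XCDS constexpr, and the per-XCD slicing scheme are assumptions drawn from this summary, not the actual code in persistent_gemm_work_stealing.py:

import triton
import triton.language as tl

@triton.jit
def ws_tile_loop_sketch(tile_counters,   # one int32 counter per XCD
                        num_tiles,       # total number of output tiles
                        NUM_XCDS: tl.constexpr):
    pid = tl.program_id(0)
    xcd = pid % NUM_XCDS  # counter owned by this workgroup's chiplet

    # Each XCD drains its own slice of the tile space; workgroups on that XCD
    # claim tile ids dynamically instead of receiving a fixed static range.
    tiles_per_xcd = tl.cdiv(num_tiles, NUM_XCDS)
    start = xcd * tiles_per_xcd
    end = tl.minimum(start + tiles_per_xcd, num_tiles)

    tile_id = start + tl.atomic_add(tile_counters + xcd, 1)
    while tile_id < end:
        # ... compute one GEMM output tile for tile_id here ...
        tile_id = start + tl.atomic_add(tile_counters + xcd, 1)

The real kernel also handles epilogue and stream-K details omitted here; the point is only that tile IDs come from an atomic counter shared by the CUs of one XCD rather than from a static per-workgroup schedule.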

Changes:

  • Added MatmulConfig class to pre-allocate and manage GPU buffers for kernel launches (tile counters, stream-K locks/partials)
  • Implemented work-stealing kernel with per-XCD atomic tile counters in persistent_gemm_work_stealing.py
  • Extended all matmul APIs with optional work_stealing and config parameters to support the new kernel (a usage sketch follows below)
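
A hedged usage sketch of the opt-in path; the MatmulConfig constructor argument shown is an assumption based on this summary, not the exact signature:

import torch
import tritonblas

A = torch.randn(8192, 8192, device="cuda", dtype=torch.float16)
B = torch.randn(8192, 8192, device="cuda", dtype=torch.float16)

# Default path: torch-like call; buffers are managed internally.
C = tritonblas.matmul(A, B)

# Opt-in work-stealing path: allocate counters/locks once in a MatmulConfig
# and reuse it across calls (constructor arguments here are illustrative).
cfg = tritonblas.MatmulConfig(device="cuda")
C_ws = tritonblas.matmul(A, B, work_stealing=True, config=cfg)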

Reviewed changes

Copilot reviewed 6 out of 6 changed files in this pull request and generated 22 comments.

File-by-file summary:

  • include/tritonblas/matmul.py: Added MatmulConfig class for buffer management; integrated work_stealing parameter and ws_persistent_matmul kernel; refactored buffer allocation to use config objects.
  • include/tritonblas/kernels/persistent_gemm_work_stealing.py: New work-stealing kernel implementation with per-XCD atomic counters and dynamic tile assignment.
  • include/tritonblas/kernels/__init__.py: Exported the ws_persistent_matmul kernel.
  • include/tritonblas/__init__.py: Exported MatmulConfig and matmul_preamble in the public API.
  • tests/test_work_stealing.py: Standalone test with custom module loading to verify work-stealing kernel correctness and performance.
  • benchmarks/benchmark_work_stealing.py: Comprehensive benchmark comparing work-stealing against static persistent, stream-K, and torch.matmul.


@ryanswann-amd self-requested a review February 12, 2026 17:29

@ryanswann-amd (Collaborator) left a comment

I think we need to think more about how we intend people to use matmul.

b: torch.Tensor,
c: torch.Tensor,
selector,
config: MatmulConfig,

Collaborator:

I'm pretty confident this makes the matmul call non-torch-like. Do we want the user to manage the locks tensor (by passing it in via the MatmulConfig object), or do we want tritonBLAS to manage it internally (like hipblasLT)?

Collaborator:

@asunderwood is a good example of a torch user. Thoughts?

Collaborator:

This particular change isn't an issue because, from a torch API perspective, the user won't be calling persistent_matmul_lt() themselves. They'll instead call matmul(), which does have a new, non-torch kwarg, but it's not the first we've added, and as long as it has a default value that lets a user skip it, it retains torch compatibility.

Comment on lines 188 to 228

Collaborator:

This results in an allocation on every call when not using a config, right? For example, if I just call C = tritonblas.matmul(A, B).

else:
locks = torch.empty(grids, device="cuda", dtype=torch.uint8)
P = torch.empty(grids, block_size, device="cuda", dtype=torch.float32)
locks = torch.empty(grids, device=cfg.device, dtype=torch.uint8)

Collaborator:

This results in an allocation on every call to C = tritonblas.matmul(A, B) if I don't pre-allocate buffers. This means that to get performance (and avoid an allocation each time), users have to pass in pre-allocated buffers, which is what we were trying to avoid with the previous approach.

Collaborator:

Aren't there already locks in the config being passed in? Shouldn't we just use those instead of re-allocating?
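
A minimal sketch of the reuse being suggested here, assuming the config object carries pre-allocated locks/P buffers; the attribute and helper names are hypothetical:

import torch

def _streamk_buffers(cfg, grids, block_size):
    if cfg is not None and getattr(cfg, "locks", None) is not None:
        # Reuse the buffers allocated once inside MatmulConfig.
        return cfg.locks, cfg.P
    # Fallback: allocate per call, which is the overhead flagged above.
    locks = torch.empty(grids, device="cuda", dtype=torch.uint8)
    P = torch.empty(grids, block_size, device="cuda", dtype=torch.float32)
    return locks, P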

neoblizz and others added 9 commits February 18, 2026 21:46
Resolve merge conflicts from 'Support async copy (#72)':

- origami.py: Keep both work-stealing params (total_cus, active_cus) and
  main's num_stages param. Keep work-stealing's multi-version origami API
  handling for select_workgroup_mapping over main's simpler version.

- matmul.py: Use main's getattr(selector, "num_stages", 2) for proper
  num_stages propagation. Keep work-stealing's pre-allocated buffer
  optimization for locks/P but use a.device (from main) in fallback path.

@ryanswann-amd (Collaborator) left a comment

Looks good other than the weird diff file.

arch.diff (Outdated)
Comment on lines +1 to +12
diff --git a/shared/origami/python/origami_module.cpp b/shared/origami/python/origami_module.cpp
index a85c5da..c4b9b07 100644
--- a/shared/origami/python/origami_module.cpp
+++ b/shared/origami/python/origami_module.cpp
@@ -154,6 +154,7 @@ NB_MODULE(origami, m) {
size_t,
std::tuple<double, double, double>>())
.def("print", &hardware_t::print)
+ .def_rw("arch", &hardware_t::arch)
.def_rw("N_CU", &hardware_t::N_CU)
.def_rw("lds_capacity", &hardware_t::lds_capacity)
.def_rw("mem1_perf_ratio", &hardware_t::mem1_perf_ratio)

Collaborator:

We should not commit diff files.

Point setup.py to ryaswann/tritonblas_expose_arch branch of
rocm-libraries which includes the .def_rw("arch", &hardware_t::arch)
change. Remove arch.diff and the git apply step since the change is
now baked into the upstream commit.