[Feature] Hierarchical reduction and warp reduction intrinsics support #1762
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
📝 Walkthrough
Adds a hierarchical CUDA reduction path and dynamic AllReduce workspace sizing.
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~45 minutes
🚥 Pre-merge checks: ✅ 2 passed | ❌ 1 failed (1 warning)
Actionable comments posted: 1
🤖 Fix all issues with AI agents
In `@testing/python/language/test_tilelang_language_reduce.py`:
- Around line 99-105: The ref_fn uses torch.uint32 in a static dtype list which
raises AttributeError on PyTorch <2.3.0; update ref_fn to conditionally include
torch.uint32 only when hasattr(torch, "uint32") (or use getattr with a fallback)
so the dtype check is built at runtime, mirroring the existing pattern used for
version-dependent dtypes like float8_e4m3fn; locate and modify the dtype
membership test in ref_fn to construct the list/set conditionally and then
perform the same res.to(A.dtype) conversion for the supported integer dtypes
(torch.uint32, torch.int32, torch.int64).
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_reduce.py (1)
7-7: Consider scoping `disable_cache()` to avoid global test side effects.
This flips a process-wide cache flag; if other tests run in the same session, they inherit the disabled cache. If that’s not intended, wrap it in a fixture/context that re-enables after this module.
```python
import torch


def ref_fn(A):
    res = A.sum(dim=1)
    if A.dtype in [torch.uint32, torch.int32, torch.int64]:
        return res.to(A.dtype)
    return res
```
🧩 Analysis chain
🏁 Script executed:
# First, find the root-level configuration files to determine PyTorch version requirements
fd -e "pyproject.toml" -e "requirements.txt" -e "setup.py" -e "setup.cfg" | head -20Repository: tile-ai/tilelang
Length of output: 42
🏁 Script executed:
# Search for PyTorch version requirements in the repository
rg -i "torch.*version|pytorch" --type toml --type txt --type py -A 2 | head -50Repository: tile-ai/tilelang
Length of output: 2715
🏁 Script executed:
# Check the specific test file and surrounding context
wc -l testing/python/language/test_tilelang_language_reduce.py
Repository: tile-ai/tilelang
Length of output: 120
🏁 Script executed:
# Look for other torch.uint32 usage in the codebase
rg "torch\.uint32" -B 2 -A 2Repository: tile-ai/tilelang
Length of output: 3742
🏁 Script executed:
# Check for existing dtype guards or compatibility patterns
rg "hasattr.*torch" -B 2 -A 2Repository: tile-ai/tilelang
Length of output: 6548
🏁 Script executed:
# Look for version requirements
cat pyproject.toml 2>/dev/null | head -100
Repository: tile-ai/tilelang
Length of output: 3126
🏁 Script executed:
# Check setup.py or requirements files
find . -maxdepth 2 -name "requirements*.txt" -o -name "setup.py" | xargs cat 2>/dev/null | grep -i torch
Repository: tile-ai/tilelang
Length of output: 130
🌐 Web query:
When was torch.uint32 introduced in PyTorch - version history
💡 Result:
torch.uint32 was introduced in PyTorch 2.3.0 (released April 24, 2024) as an unsigned 32-bit “shell dtype” (available as a dtype object, but with limited eager-mode operator support). [1][2][3]
Before that, PyTorch generally did not support uint32/uint64 (beyond uint8). [4]
Sources:
[1] (dev-discuss.pytorch.org)
[2] (github.com)
[3] (docs.pytorch.org)
[4] (discuss.pytorch.org)
Citations:
- 1: https://dev-discuss.pytorch.org/t/pytorch-2-3-0-general-availability/2033?utm_source=openai
- 2: About uint16 support pytorch/vision#8359
- 3: https://docs.pytorch.org/docs/stable/tensor_attributes.html?utm_source=openai
- 4: https://discuss.pytorch.org/t/pytorch-uint32-uint64/78263?utm_source=openai
🏁 Script executed:
# Look at the context around the test function in question
sed -n '85,115p' testing/python/language/test_tilelang_language_reduce.py
Repository: tile-ai/tilelang
Length of output: 905
Guard torch.uint32 usage for torch-version compatibility.
torch.uint32 was introduced in PyTorch 2.3.0. Since the repository has no minimum PyTorch version requirement, this code will raise AttributeError at function definition time on PyTorch < 2.3.0, breaking tests even for other dtypes. Use a hasattr() check to build the dtype set conditionally, matching the pattern already used in the codebase for version-dependent dtypes like float8_e4m3fn.
🛠️ Suggested compatibility guard
```diff
+_int_dtypes = {torch.int32, torch.int64}
+if hasattr(torch, "uint32"):
+    _int_dtypes.add(torch.uint32)
+
 def ref_fn(A):
     res = A.sum(dim=1)
-    if A.dtype in [torch.uint32, torch.int32, torch.int64]:
+    if A.dtype in _int_dtypes:
         return res.to(A.dtype)
     return res
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
import torch

_int_dtypes = {torch.int32, torch.int64}
if hasattr(torch, "uint32"):
    _int_dtypes.add(torch.uint32)


def ref_fn(A):
    res = A.sum(dim=1)
    if A.dtype in _int_dtypes:
        return res.to(A.dtype)
    return res
```
🤖 Prompt for AI Agents
In `@testing/python/language/test_tilelang_language_reduce.py` around lines 99 -
105, The ref_fn uses torch.uint32 in a static dtype list which raises
AttributeError on PyTorch <2.3.0; update ref_fn to conditionally include
torch.uint32 only when hasattr(torch, "uint32") (or use getattr with a fallback)
so the dtype check is built at runtime, mirroring the existing pattern used for
version-dependent dtypes like float8_e4m3fn; locate and modify the dtype
membership test in ref_fn to construct the list/set conditionally and then
perform the same res.to(A.dtype) conversion for the supported integer dtypes
(torch.uint32, torch.int32, torch.int64).
Maybe we also need the performance regression tests. Where can I trigger them? @LeiWang1999
@regression-perf
Thanks for your contribution! I've further investigated this problem this weekend and have some suggestions for the code:
@tzj-fxz Would you mind taking a look?
Thank you for the advice. I will further implement these features. |
src/op/reduce.cc
Outdated
```cpp
if (reducing_threads > 32) {
  PrimExpr workspace = T.AddWorkspace(
      *as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
int workspace_size = (reducing_threads > 32 &&
```
It would be better to refactor this into a lambda function; that way, we can have cleaner code and more useful comments.
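For illustration only, the sizing logic could be pulled into a small lambda like the sketch below. The name `workspace_slots` and the one-slot-per-warp sizing rule are assumptions made for this example, not the PR's actual formula; `reducing_threads` and `all_threads` are assumed to come from the surrounding lowering code.

```cpp
// Sketch: decide how many AllReduce workspace slots the lowering should
// request. Zero means the reduction fits inside one warp and can rely
// purely on warp shuffles; otherwise reserve one slot per participating
// warp for the cross-warp step. The sizing rule here is illustrative.
auto workspace_slots = [](int reducing_threads, int all_threads) -> int {
  constexpr int kWarpSize = 32;
  if (reducing_threads <= kWarpSize) {
    return 0;  // intra-warp shuffle reduction needs no scratch space
  }
  return (all_threads + kWarpSize - 1) / kWarpSize;  // one slot per warp
};
```

Call sites would then reduce to a single line such as `int workspace_size = workspace_slots(reducing_threads, all_threads);`, with the reasoning documented next to the lambda rather than inline in the branch condition.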
```cpp
template <class Reducer, int threads, int scale, int thread_offset = 0,
          int all_threads = threads>
struct AllReduce {
  static_assert(threads == 1024 or threads == 512 or threads == 256 or
```
Why do we need to remove this piece of code?
It is more flexible to support reduction across any number of warps (not constrained to powers of two) by doing a cross-warp reduction after the intra-warp reduction.
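To make that two-level scheme concrete, a generic block-wide sum could look like the sketch below. This is an illustration of the idea, not the code in `reduce.h`; the name `BlockReduceSum` is hypothetical and the sketch assumes `blockDim.x` is a multiple of 32.

```cuda
// Sketch of a hierarchical block-wide sum:
//   1) each warp reduces its 32 values with butterfly shuffles,
//   2) lane 0 of every warp stores its partial in shared memory,
//   3) warp 0 reduces the per-warp partials.
// The warp count may be any value up to 32; it need not be a power of two.
template <int BLOCK_THREADS>
__device__ float BlockReduceSum(float value) {
  constexpr int kWarpSize = 32;
  constexpr int kNumWarps = BLOCK_THREADS / kWarpSize;  // assumes multiple of 32
  __shared__ float warp_partials[kNumWarps];

  const int lane = threadIdx.x % kWarpSize;
  const int warp = threadIdx.x / kWarpSize;

  // Step 1: intra-warp reduction via shuffles.
  for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
    value += __shfl_xor_sync(0xffffffffu, value, offset);
  }

  // Step 2: publish one partial per warp.
  if (lane == 0) {
    warp_partials[warp] = value;
  }
  __syncthreads();

  // Step 3: warp 0 reduces the partials; lanes beyond kNumWarps contribute 0.
  if (warp == 0) {
    value = (lane < kNumWarps) ? warp_partials[lane] : 0.0f;
    for (int offset = kWarpSize / 2; offset > 0; offset >>= 1) {
      value += __shfl_xor_sync(0xffffffffu, value, offset);
    }
  }
  return value;  // every lane of warp 0 now holds the block-wide sum
}
```

Because step 3 only reads `kNumWarps` partials and pads the rest with the identity, it works for, say, 3 or 5 warps just as well as for 4 or 8, which is the flexibility being described.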
redux PTX support
Actionable comments posted: 2
🤖 Fix all issues with AI agents
In `@src/tl_templates/cuda/reduce.h`:
- Around line 377-381: The fast-path incorrectly casts 64-bit integrals to
int32_t and uses run_reduce_sync, truncating results; update the compile-time
branch in the template so the run_reduce_sync path is only taken when T is
exactly int32_t or uint32_t (remove the branch that treats other
std::is_integral_v<T> types by casting), and let 64-bit types fall through to
the shfl_xor_sync-based fallback that correctly handles 64-bit operands; adjust
the condition around run_reduce_sync (the constexpr if that currently checks
std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t> and the following
std::is_integral_v<T> branch) accordingly so no cast-to-int32_t is performed for
64-bit integrals.
- Around line 344-358: The sm100 fast-path incorrectly casts floats and calls
__reduce_max_sync/__reduce_min_sync which only accept 32-bit integer operands;
modify the branch in reduce.h so the __reduce_*_sync calls are only compiled for
32-bit integer types (e.g., when std::is_same_v<T,int> or
std::is_same_v<T,unsigned int> / std::is_integral and sizeof(T)==4), and remove
the float/half/bfloat casting there; for all other types (float, half_t,
bfloat16_t, etc.) fall through to the existing non-sm100 handling. Ensure you
static_cast the value to int/unsigned before passing to
__reduce_max_sync/__reduce_min_sync and keep the ReduceOp/MaxOp/MinOp constexpr
checks to select the right intrinsic.
```cpp
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) &&                       \
    (defined(__CUDA_ARCH_FEAT_SM100_ALL) || defined(__CUDA_ARCH_FEAT_SM100_F))
  float value_cast = 0.0f;
  if constexpr (std::is_same_v<T, half_t>) {
    value_cast = __half2float(value);
  } else if constexpr (std::is_same_v<T, bfloat16_t>) {
    value_cast = __bfloat162float(value);
  } else {
    value_cast = static_cast<float>(value);
  }
  if constexpr (std::is_same_v<ReduceOp, MaxOp>) {
    return __reduce_max_sync(mask, value_cast);
  } else if constexpr (std::is_same_v<ReduceOp, MinOp>) {
    return __reduce_min_sync(mask, value_cast);
  }
```
🧩 Analysis chain
🌐 Web query:
CUDA __reduce_max_sync __reduce_min_sync operand types sm100 documentation
💡 Result:
For SM100 (compute capability 10.0), the CUDA warp reduce intrinsics __reduce_min_sync / __reduce_max_sync are documented with only 32-bit integer operand types:
```cpp
unsigned __reduce_min_sync(unsigned mask, unsigned value);
unsigned __reduce_max_sync(unsigned mask, unsigned value);
int __reduce_min_sync(unsigned mask, int value);
int __reduce_max_sync(unsigned mask, int value);
```
[1]
CUDA’s programming guide also states that for {add, min, max} the reduced value type T “can be unsigned or signed” (and these intrinsics are supported on devices of compute capability 8.x or higher, which includes sm_100). [1]
Implication: there are no documented float, half, bf16, or 64-bit (long long, uint64) overloads for __reduce_min_sync / __reduce_max_sync; only int and unsigned are specified. [1]
Sources
- NVIDIA CUDA C++ Programming Guide (CUDA Toolkit 12.9.1), “Warp Reduce Functions” (synopsis + type rules). [1]
The sm100 Max/Min fast-path uses unsupported operand types with __reduce_*_sync intrinsics.
The __reduce_max_sync and __reduce_min_sync intrinsics only support 32-bit integer types (int and unsigned), not floating-point types. The current code casts any T to float and passes it to these functions, causing a type mismatch. Additionally, a guard that merely restricts this path to floating types won't work, because float, half_t, and bfloat16_t are all unsupported by these intrinsics. Instead, this fast-path should apply only to int/unsigned types; all other types should fall through to alternative handling.
🔧 Proposed fix
```diff
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) &&                      \
     (defined(__CUDA_ARCH_FEAT_SM100_ALL) || defined(__CUDA_ARCH_FEAT_SM100_F))
-  float value_cast = 0.0f;
-  if constexpr (std::is_same_v<T, half_t>) {
-    value_cast = __half2float(value);
-  } else if constexpr (std::is_same_v<T, bfloat16_t>) {
-    value_cast = __bfloat162float(value);
-  } else {
-    value_cast = static_cast<float>(value);
-  }
-  if constexpr (std::is_same_v<ReduceOp, MaxOp>) {
-    return __reduce_max_sync(mask, value_cast);
-  } else if constexpr (std::is_same_v<ReduceOp, MinOp>) {
-    return __reduce_min_sync(mask, value_cast);
-  }
+  if constexpr (std::is_same_v<T, int> || std::is_same_v<T, unsigned>) {
+    if constexpr (std::is_same_v<ReduceOp, MaxOp>) {
+      return __reduce_max_sync(mask, value);
+    } else if constexpr (std::is_same_v<ReduceOp, MinOp>) {
+      return __reduce_min_sync(mask, value);
+    }
+  }
 #endif
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) &&                       \
    (defined(__CUDA_ARCH_FEAT_SM100_ALL) || defined(__CUDA_ARCH_FEAT_SM100_F))
  if constexpr (std::is_same_v<T, int> || std::is_same_v<T, unsigned>) {
    if constexpr (std::is_same_v<ReduceOp, MaxOp>) {
      return __reduce_max_sync(mask, value);
    } else if constexpr (std::is_same_v<ReduceOp, MinOp>) {
      return __reduce_min_sync(mask, value);
    }
  }
#endif
```
🤖 Prompt for AI Agents
In `@src/tl_templates/cuda/reduce.h` around lines 344 - 358, The sm100 fast-path
incorrectly casts floats and calls __reduce_max_sync/__reduce_min_sync which
only accept 32-bit integer operands; modify the branch in reduce.h so the
__reduce_*_sync calls are only compiled for 32-bit integer types (e.g., when
std::is_same_v<T,int> or std::is_same_v<T,unsigned int> / std::is_integral and
sizeof(T)==4), and remove the float/half/bfloat casting there; for all other
types (float, half_t, bfloat16_t, etc.) fall through to the existing non-sm100
handling. Ensure you static_cast the value to int/unsigned before passing to
__reduce_max_sync/__reduce_min_sync and keep the ReduceOp/MaxOp/MinOp constexpr
checks to select the right intrinsic.
```cpp
if constexpr (std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t>) {
  return run_reduce_sync(value);
} else if constexpr (std::is_integral_v<T>) {
  return static_cast<T>(run_reduce_sync(static_cast<int32_t>(value)));
}
```
🧩 Analysis chain
🏁 Script executed:
# First, let's understand the file structure and find the reduce.h file
find . -name "reduce.h" -type f | head -20
Repository: tile-ai/tilelang
Length of output: 142
🏁 Script executed:
# Read the reduce.h file around the lines in question to understand context
wc -l src/tl_templates/cuda/reduce.h
Repository: tile-ai/tilelang
Length of output: 94
🏁 Script executed:
# Get context around lines 377-381
sed -n '350,410p' src/tl_templates/cuda/reduce.h
Repository: tile-ai/tilelang
Length of output: 2198
🏁 Script executed:
# Search for run_reduce_sync definition to understand what it does
rg "run_reduce_sync" --type cpp --type h -B 3 -A 10Repository: tile-ai/tilelang
Length of output: 2178
🌐 Web query:
CUDA __reduce_sync functions operand size constraints 32-bit 64-bit
💡 Result:
CUDA’s warp `__reduce_*_sync(mask, value)` intrinsics are limited to 32-bit integer operands:
- The only provided overloads are for `unsigned` and `int` (both 32-bit). There are no 64-bit (`long long`/`unsigned long long`) overloads, so you can’t directly reduce 64-bit values with these intrinsics. [1]
- If you need a warp reduction on 64-bit data, use a shuffle-based reduction (e.g., `__shfl_xor_sync`), which supports `long long`/`unsigned long long` and generally moves 8 bytes per thread for 64-bit types. [1]
Sources:
[1] NVIDIA CUDA C++ Programming Guide, sections “10.21 Warp Reduce Functions” and “10.22 Warp Shuffle Functions”. (docs.nvidia.com)
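For reference, the shuffle-based fallback described above could look roughly like the sketch below. This is a generic illustration, not the exact code in `reduce.h`, and the helper name `WarpSumI64` is hypothetical.

```cuda
// Warp-wide sum of a 64-bit integer. __reduce_*_sync has no 64-bit
// overloads, but __shfl_xor_sync accepts long long directly (moving
// 8 bytes per lane), so the butterfly reduction keeps full precision.
__device__ long long WarpSumI64(long long value) {
  for (int offset = 16; offset > 0; offset >>= 1) {
    value += __shfl_xor_sync(0xffffffffu, value, offset);
  }
  return value;  // with a full-warp mask, every lane holds the sum
}
```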
Avoid truncating 64‑bit integral reductions.
CUDA's __reduce_*_sync() intrinsics support only 32-bit integer operands. The current code casts int64_t and uint64_t to int32_t before reduction, truncating the upper 32 bits and corrupting results. Restrict the fast-path to 32-bit integrals and let 64-bit types fall through to the shfl_xor_sync-based fallback, which properly handles 64-bit operands.
🔧 Proposed fix
```diff
-  if constexpr (std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t>) {
-    return run_reduce_sync(value);
-  } else if constexpr (std::is_integral_v<T>) {
-    return static_cast<T>(run_reduce_sync(static_cast<int32_t>(value)));
-  }
+  if constexpr (std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t>) {
+    return run_reduce_sync(value);
+  } else if constexpr (std::is_integral_v<T> && sizeof(T) <= 4) {
+    return static_cast<T>(run_reduce_sync(static_cast<int32_t>(value)));
+  }
```
📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```cpp
if constexpr (std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t>) {
  return run_reduce_sync(value);
} else if constexpr (std::is_integral_v<T> && sizeof(T) <= 4) {
  return static_cast<T>(run_reduce_sync(static_cast<int32_t>(value)));
}
```
🤖 Prompt for AI Agents
In `@src/tl_templates/cuda/reduce.h` around lines 377 - 381, The fast-path
incorrectly casts 64-bit integrals to int32_t and uses run_reduce_sync,
truncating results; update the compile-time branch in the template so the
run_reduce_sync path is only taken when T is exactly int32_t or uint32_t (remove
the branch that treats other std::is_integral_v<T> types by casting), and let
64-bit types fall through to the shfl_xor_sync-based fallback that correctly
handles 64-bit operands; adjust the condition around run_reduce_sync (the
constexpr if that currently checks std::is_same_v<T, int32_t> ||
std::is_same_v<T, uint32_t> and the following std::is_integral_v<T> branch)
accordingly so no cast-to-int32_t is performed for 64-bit integrals.
As far as I know, it's indeed complex. It seems that …
For #1761
- `redux.sync` PTX templates to support faster reduction on (u)int32 with `__CUDA_ARCH__ >= 800` (a sketch of such a fast path follows below).

Summary by CodeRabbit
Performance
Bug Fixes
Tests
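To make the `redux.sync` description above concrete, a guarded fast path of the kind the PR targets could look roughly like this. It is an illustrative sketch, not the PR's actual template code; the helper name `WarpSumU32` is hypothetical, and `__reduce_add_sync` (the CUDA intrinsic that maps to the `redux.sync` PTX instruction) is available for 32-bit integers on compute capability 8.0 and newer.

```cuda
// Illustrative warp sum for unsigned 32-bit values: use the hardware
// redux.sync path on sm_80+ and fall back to shuffles elsewhere.
__device__ unsigned WarpSumU32(unsigned value) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return __reduce_add_sync(0xffffffffu, value);
#else
  // Pre-sm_80 fallback: butterfly shuffle reduction.
  for (int offset = 16; offset > 0; offset >>= 1) {
    value += __shfl_xor_sync(0xffffffffu, value, offset);
  }
  return value;
#endif
}
```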