
Conversation

@tzj-fxz
Contributor

@tzj-fxz tzj-fxz commented Jan 31, 2026

For #1761

  • Add a hierarchical warp-to-block reduction to shrink the reduction workspace size
  • Add redux.sync PTX templates for faster reduction on (u)int32 when __CUDA_ARCH__ >= 800 (a rough sketch of the idea follows)
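
A minimal sketch of the redux.sync idea from the second bullet (illustrative only; the helper name and guard are assumptions, not the PR's actual template):

// Warp-wide add of a signed 32-bit value using the redux.sync PTX
// instruction, available on sm_80 and newer; older architectures fall back
// to a shuffle-based reduction.
__device__ __forceinline__ int warp_reduce_add_i32(int v,
                                                   unsigned mask = 0xffffffffu) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  int r;
  asm volatile("redux.sync.add.s32 %0, %1, %2;" : "=r"(r) : "r"(v), "r"(mask));
  return r;
#else
  for (int offset = 16; offset > 0; offset >>= 1)
    v += __shfl_xor_sync(mask, v, offset);  // shuffle fallback
  return v;
#endif
}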

Summary by CodeRabbit

  • Performance

    • More efficient reduce execution with dynamic workspace sizing and architecture-specific CUDA optimizations for faster, lower-memory reductions.
  • Bug Fixes

    • Improved handling and validation for integer reduce results to ensure correct casts and accuracy.
  • Tests

    • Added robust reference-based tests covering int32 reduce across multiple tensor shapes.

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai
Contributor

coderabbitai bot commented Jan 31, 2026

📝 Walkthrough

Adds a hierarchical CUDA reduction path and dynamic AllReduce workspace sizing: introduces warp_reduce and warp-aware multi-stage reductions in the CUDA templates, makes workspace allocation in the AllReduce path conditional on the reduction shape, and expands the integer reduce_sum tests with a Torch-backed reference.
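
For orientation, a minimal sketch of the hierarchical pattern described above, assuming a block size that is a multiple of 32; names are illustrative and not the PR's actual code:

// Stage 1: each warp reduces its own values with shuffles.
// Stage 2: lane 0 of every warp writes its partial to a shared workspace of
//          blockDim.x / 32 elements, then the first warp reduces the partials.
template <typename T>
__device__ T block_reduce_sum(T val, T *warp_partials) {
  for (int offset = 16; offset > 0; offset >>= 1)
    val += __shfl_xor_sync(0xffffffffu, val, offset);

  const int lane = threadIdx.x % 32;
  const int warp = threadIdx.x / 32;
  if (lane == 0)
    warp_partials[warp] = val;          // one slot per warp, not per thread
  __syncthreads();

  if (warp == 0) {
    const int num_warps = blockDim.x / 32;
    val = (lane < num_warps) ? warp_partials[lane] : T(0);
    for (int offset = 16; offset > 0; offset >>= 1)
      val += __shfl_xor_sync(0xffffffffu, val, offset);
  }
  return val;  // final sum is valid in lane 0 of warp 0
}

This is why the workspace in src/op/reduce.cc can shrink from one element per reducing thread to one element per warp (reducing_threads / 32).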

Changes

Core reduction runtime (src/op/reduce.cc)
Conditional workspace sizing for AllReduce: compute workspace as reducing_threads/32 when >32 and divisible by 32 and scale==1; pass the reduced workspace to the AllReduce thread reduction instead of the full extent.

CUDA reduction templates (src/tl_templates/cuda/reduce.h)
Adds warp_reduce declaration and implementation; introduces warp-level and hierarchical reduction paths (warp-first reduction, per-warp buffering, final intra-group reduction), uses architecture-specific intrinsics and FP16/BF16 handling; retains the shfl_xor_sync fallback.

Tests (testing/python/language/test_tilelang_language_reduce.py)
Disables cache at startup, replaces the simple sum check with a Torch-backed reference handling integer dtypes, and adds test_reduce_sum_int32() exercising int32 reduce_sum across shapes.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~45 minutes


Suggested reviewers

  • LeiWang1999

Poem

🐰 Hopping through warps with a twitchy nose,
Buffers gather where the shared memory grows.
Threads whisper sums in a tidy chain,
Workspace trimmed — the reductions reign! 🥕✨

🚥 Pre-merge checks | ✅ 2 passed | ❌ 1 warning
❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 11.11%, which is below the required threshold of 80.00%. Resolution: write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed (check skipped; CodeRabbit's high-level summary is enabled).
  • Title check: ✅ Passed. The title clearly and specifically describes the main changes (hierarchical reduction and warp reduction intrinsics support), which align with the core modifications across all changed files.


Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@testing/python/language/test_tilelang_language_reduce.py`:
- Around line 99-105: The ref_fn uses torch.uint32 in a static dtype list which
raises AttributeError on PyTorch <2.3.0; update ref_fn to conditionally include
torch.uint32 only when hasattr(torch, "uint32") (or use getattr with a fallback)
so the dtype check is built at runtime, mirroring the existing pattern used for
version-dependent dtypes like float8_e4m3fn; locate and modify the dtype
membership test in ref_fn to construct the list/set conditionally and then
perform the same res.to(A.dtype) conversion for the supported integer dtypes
(torch.uint32, torch.int32, torch.int64).
🧹 Nitpick comments (1)
testing/python/language/test_tilelang_language_reduce.py (1)

7-7: Consider scoping disable_cache() to avoid global test side effects.
This flips a process-wide cache flag; if other tests run in the same session, they inherit the disabled cache. If that’s not intended, wrap it in a fixture/context that re-enables after this module.

Comment on lines +99 to +105
import torch

def ref_fn(A):
    res = A.sum(dim=1)
    if A.dtype in [torch.uint32, torch.int32, torch.int64]:
        return res.to(A.dtype)
    return res
Contributor

⚠️ Potential issue | 🟡 Minor

🧩 Analysis chain

🏁 Script executed:

# First, find the root-level configuration files to determine PyTorch version requirements
fd -e "pyproject.toml" -e "requirements.txt" -e "setup.py" -e "setup.cfg" | head -20

Repository: tile-ai/tilelang

Length of output: 42


🏁 Script executed:

# Search for PyTorch version requirements in the repository
rg -i "torch.*version|pytorch" --type toml --type txt --type py -A 2 | head -50

Repository: tile-ai/tilelang

Length of output: 2715


🏁 Script executed:

# Check the specific test file and surrounding context
wc -l testing/python/language/test_tilelang_language_reduce.py

Repository: tile-ai/tilelang

Length of output: 120


🏁 Script executed:

# Look for other torch.uint32 usage in the codebase
rg "torch\.uint32" -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 3742


🏁 Script executed:

# Check for existing dtype guards or compatibility patterns
rg "hasattr.*torch" -B 2 -A 2

Repository: tile-ai/tilelang

Length of output: 6548


🏁 Script executed:

# Look for version requirements
cat pyproject.toml 2>/dev/null | head -100

Repository: tile-ai/tilelang

Length of output: 3126


🏁 Script executed:

# Check setup.py or requirements files
find . -maxdepth 2 -name "requirements*.txt" -o -name "setup.py" | xargs cat 2>/dev/null | grep -i torch

Repository: tile-ai/tilelang

Length of output: 130


🌐 Web query:

When was torch.uint32 introduced in PyTorch - version history

💡 Result:

torch.uint32 was introduced in PyTorch 2.3.0 (released April 24, 2024) as an unsigned 32-bit “shell dtype” (available as a dtype object, but with limited eager-mode operator support). [1][2][3]

Before that, PyTorch generally did not support uint32/uint64 (beyond uint8). [4]

Sources:
[1] (dev-discuss.pytorch.org)
[2] (github.com)
[3] (docs.pytorch.org)
[4] (discuss.pytorch.org)



🏁 Script executed:

# Look at the context around the test function in question
sed -n '85,115p' testing/python/language/test_tilelang_language_reduce.py

Repository: tile-ai/tilelang

Length of output: 905


Guard torch.uint32 usage for torch-version compatibility.
torch.uint32 was introduced in PyTorch 2.3.0. Since the repository has no minimum PyTorch version requirement, this code will raise AttributeError at function definition time on PyTorch < 2.3.0, breaking tests even for other dtypes. Use a hasattr() check to build the dtype set conditionally, matching the pattern already used in the codebase for version-dependent dtypes like float8_e4m3fn.

🛠️ Suggested compatibility guard
+    _int_dtypes = {torch.int32, torch.int64}
+    if hasattr(torch, "uint32"):
+        _int_dtypes.add(torch.uint32)
+
     def ref_fn(A):
         res = A.sum(dim=1)
-        if A.dtype in [torch.uint32, torch.int32, torch.int64]:
+        if A.dtype in _int_dtypes:
             return res.to(A.dtype)
         return res

@tzj-fxz
Contributor Author

tzj-fxz commented Jan 31, 2026

Maybe we also need the performance regression tests. Where can I trigger them? @LeiWang1999

@Rachmanino
Collaborator

@regression-perf

@bucket-xv

Thanks for your contribution! I've further investigated this problem this weekend and have some suggestions for the code:

  1. Use intrinsic functions instead of raw redux PTX for maintainability. Refer to https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-reduce-functions. Note that these are also for sm80+.
  2. Maybe leverage the redux instruction for f32 types as well? This has been supported since sm_100a. Refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-redux-sync.
  3. For warp-level reduction, any floating type could be cast to f32 and any integral type to int32. This cast helps leverage the redux instruction, and it is almost always better for performance, since shfl.sync already performs an implicit cast (it requires b32 types). (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-shfl-sync)
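
For reference, a rough sketch of suggestion 1 above, using the documented warp-reduce intrinsic instead of raw PTX (the helper name is hypothetical; per the linked guide, only int/unsigned overloads of the intrinsic exist):

// __reduce_add_sync is available on compute capability 8.x and newer and
// maps to redux.sync; older architectures fall back to shuffle reduction.
__device__ __forceinline__ int warp_reduce_add(unsigned mask, int v) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
  return __reduce_add_sync(mask, v);
#else
  for (int offset = 16; offset > 0; offset >>= 1)
    v += __shfl_xor_sync(mask, v, offset);
  return v;
#endif
}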

@LeiWang1999
Member

@tzj-fxz Would you mind taking a look?

@tzj-fxz
Contributor Author

tzj-fxz commented Feb 2, 2026

Thanks for your contribution! I've further investigated this problem this weekend and have some suggestions for the code:

  1. Use intrinsic functions instead of raw redux PTX for maintainability. Refer to https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#warp-reduce-functions. Note that these are also for sm80+.
  2. Maybe leverage the redux instruction for f32 types as well? This has been supported since sm_100a. Refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-redux-sync.
  3. For warp-level reduction, any floating type could be cast to f32 and any integral type to int32. This cast helps leverage the redux instruction, and it is almost always better for performance, since shfl.sync already performs an implicit cast (it requires b32 types). (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-shfl-sync)

Thank you for the advice. I will further implement these features.

src/op/reduce.cc Outdated
if (reducing_threads > 32) {
PrimExpr workspace = T.AddWorkspace(
*as_const_int(T.thread_bounds->extent), clear_buffer->dtype);
int workspace_size = (reducing_threads > 32 &&
Member

It would be better to refactor this into a lambda function; that way, we can have cleaner code and more useful comments
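
For illustration, one way the suggested lambda refactor could look (an assumption about the final shape of the patch, mirroring the condition described in the walkthrough):

// Decide how many workspace slots AllReduce needs: one per warp when the
// reduction spans whole warps and is unscaled, otherwise one per thread.
auto compute_workspace_size = [](int reducing_threads, int scale) -> int {
  if (reducing_threads > 32 && reducing_threads % 32 == 0 && scale == 1)
    return reducing_threads / 32;  // hierarchical path: per-warp partials only
  return reducing_threads;         // fallback: one slot per reducing thread
};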

template <class Reducer, int threads, int scale, int thread_offset = 0,
          int all_threads = threads>
struct AllReduce {
  static_assert(threads == 1024 or threads == 512 or threads == 256 or
Member

Why do we need to remove this piece of code?

Contributor Author

It is more flexible: doing a warp-level reduction over the per-warp partials after the intra-warp reduction supports any number of warps (not just powers of 2).
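
A small illustration of that point (hypothetical code, not the PR's): once each warp has written one partial, the final pass only needs to cover however many warps the block actually has, so non-power-of-two warp counts work naturally.

// Combine per-warp partials for an arbitrary warp count (e.g. 3, 5, or 7 warps).
__device__ float combine_warp_partials(const float *partials, int num_warps) {
  float acc = partials[0];
  for (int w = 1; w < num_warps; ++w)
    acc += partials[w];
  return acc;
}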

@tzj-fxz tzj-fxz changed the title from "[Feature] Hierarchical reduction and redux PTX support" to "[Feature] Hierarchical reduction and warp reduction intrinsics support" on Feb 2, 2026
Contributor

@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 2

🤖 Fix all issues with AI agents
In `@src/tl_templates/cuda/reduce.h`:
- Around line 377-381: The fast-path incorrectly casts 64-bit integrals to
int32_t and uses run_reduce_sync, truncating results; update the compile-time
branch in the template so the run_reduce_sync path is only taken when T is
exactly int32_t or uint32_t (remove the branch that treats other
std::is_integral_v<T> types by casting), and let 64-bit types fall through to
the shfl_xor_sync-based fallback that correctly handles 64-bit operands; adjust
the condition around run_reduce_sync (the constexpr if that currently checks
std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t> and the following
std::is_integral_v<T> branch) accordingly so no cast-to-int32_t is performed for
64-bit integrals.
- Around line 344-358: The sm100 fast-path incorrectly casts floats and calls
__reduce_max_sync/__reduce_min_sync which only accept 32-bit integer operands;
modify the branch in reduce.h so the __reduce_*_sync calls are only compiled for
32-bit integer types (e.g., when std::is_same_v<T,int> or
std::is_same_v<T,unsigned int> / std::is_integral and sizeof(T)==4), and remove
the float/half/bfloat casting there; for all other types (float, half_t,
bfloat16_t, etc.) fall through to the existing non-sm100 handling. Ensure you
static_cast the value to int/unsigned before passing to
__reduce_max_sync/__reduce_min_sync and keep the ReduceOp/MaxOp/MinOp constexpr
checks to select the right intrinsic.

Comment on lines +344 to +358
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) && \
    (defined(__CUDA_ARCH_FEAT_SM100_ALL) || defined(__CUDA_ARCH_FEAT_SM100_F))
  float value_cast = 0.0f;
  if constexpr (std::is_same_v<T, half_t>) {
    value_cast = __half2float(value);
  } else if constexpr (std::is_same_v<T, bfloat16_t>) {
    value_cast = __bfloat162float(value);
  } else {
    value_cast = static_cast<float>(value);
  }
  if constexpr (std::is_same_v<ReduceOp, MaxOp>) {
    return __reduce_max_sync(mask, value_cast);
  } else if constexpr (std::is_same_v<ReduceOp, MinOp>) {
    return __reduce_min_sync(mask, value_cast);
  }
Contributor

⚠️ Potential issue | 🔴 Critical

🧩 Analysis chain

🌐 Web query:

CUDA __reduce_max_sync __reduce_min_sync operand types sm100 documentation

💡 Result:

For SM100 (compute capability 10.0), the CUDA warp reduce intrinsics __reduce_min_sync / __reduce_max_sync are documented with only 32-bit integer operand types:

  • unsigned __reduce_min_sync(unsigned mask, unsigned value);
  • unsigned __reduce_max_sync(unsigned mask, unsigned value);
  • int __reduce_min_sync(unsigned mask, int value);
  • int __reduce_max_sync(unsigned mask, int value);  [1]

CUDA’s programming guide also states that for {add, min, max} the reduced value type T “can be unsigned or signed” (and these intrinsics are supported on devices of compute capability 8.x or higher, which includes sm_100). [1]

Implication: there are no documented float, half, bf16, or 64-bit (long long, uint64) overloads for __reduce_min_sync / __reduce_max_sync; only int and unsigned are specified. [1]

Sources

  1. NVIDIA CUDA C++ Programming Guide (CUDA Toolkit 12.9.1), “Warp Reduce Functions” (synopsis + type rules). [1]

The sm100 Max/Min fast-path uses unsupported operand types with __reduce_*_sync intrinsics.

The __reduce_max_sync and __reduce_min_sync intrinsics only support 32-bit integer types (int and unsigned), not floating-point types. The current code casts any T to float and passes it to these functions, causing a type mismatch. Additionally, the proposed guard to floating types won't work because float, half_t, and bfloat16_t are unsupported by these intrinsics. Instead, this fast-path should apply only to int/unsigned types; all other types should fall through to alternative handling.

🔧 Proposed fix
 #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) &&                       \
     (defined(__CUDA_ARCH_FEAT_SM100_ALL) || defined(__CUDA_ARCH_FEAT_SM100_F))
-  float value_cast = 0.0f;
-  if constexpr (std::is_same_v<T, half_t>) {
-    value_cast = __half2float(value);
-  } else if constexpr (std::is_same_v<T, bfloat16_t>) {
-    value_cast = __bfloat162float(value);
-  } else {
-    value_cast = static_cast<float>(value);
-  }
-  if constexpr (std::is_same_v<ReduceOp, MaxOp>) {
-    return __reduce_max_sync(mask, value_cast);
-  } else if constexpr (std::is_same_v<ReduceOp, MinOp>) {
-    return __reduce_min_sync(mask, value_cast);
-  }
+  if constexpr (std::is_same_v<T, int> || std::is_same_v<T, unsigned>) {
+    if constexpr (std::is_same_v<ReduceOp, MaxOp>) {
+      return __reduce_max_sync(mask, value);
+    } else if constexpr (std::is_same_v<ReduceOp, MinOp>) {
+      return __reduce_min_sync(mask, value);
+    }
+  }
 #endif

Comment on lines +377 to +381
if constexpr (std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t>) {
  return run_reduce_sync(value);
} else if constexpr (std::is_integral_v<T>) {
  return static_cast<T>(run_reduce_sync(static_cast<int32_t>(value)));
}
Contributor

⚠️ Potential issue | 🟠 Major

🧩 Analysis chain

🏁 Script executed:

# First, let's understand the file structure and find the reduce.h file
find . -name "reduce.h" -type f | head -20

Repository: tile-ai/tilelang

Length of output: 142


🏁 Script executed:

# Read the reduce.h file around the lines in question to understand context
wc -l src/tl_templates/cuda/reduce.h

Repository: tile-ai/tilelang

Length of output: 94


🏁 Script executed:

# Get context around lines 377-381
sed -n '350,410p' src/tl_templates/cuda/reduce.h

Repository: tile-ai/tilelang

Length of output: 2198


🏁 Script executed:

# Search for run_reduce_sync definition to understand what it does
rg "run_reduce_sync" --type cpp --type h -B 3 -A 10

Repository: tile-ai/tilelang

Length of output: 2178


🌐 Web query:

CUDA __reduce_sync functions operand size constraints 32-bit 64-bit

💡 Result:

CUDA’s warp __reduce_*_sync(mask, value) intrinsics are limited to 32-bit integer operands:

  • The only provided overloads are for unsigned and int (both 32-bit). There are no 64-bit (long long / unsigned long long) overloads, so you can’t directly reduce 64-bit values with these intrinsics. [1]
  • If you need a warp reduction on 64-bit data, use a shuffle-based reduction (e.g., __shfl_xor_sync), which supports long long / unsigned long long and generally moves 8 bytes per thread for 64-bit types. [1]

Sources:
[1] NVIDIA CUDA C++ Programming Guide, sections “10.21 Warp Reduce Functions” and “10.22 Warp Shuffle Functions”. (docs.nvidia.com)



Avoid truncating 64‑bit integral reductions.

CUDA's __reduce_*_sync() intrinsics support only 32-bit integer operands. The current code casts int64_t and uint64_t to int32_t before reduction, truncating the upper 32 bits and corrupting results. Restrict the fast-path to 32-bit integrals and let 64-bit types fall through to the shfl_xor_sync-based fallback, which properly handles 64-bit operands.

🔧 Proposed fix
-  if constexpr (std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t>) {
-    return run_reduce_sync(value);
-  } else if constexpr (std::is_integral_v<T>) {
-    return static_cast<T>(run_reduce_sync(static_cast<int32_t>(value)));
-  }
+  if constexpr (std::is_same_v<T, int32_t> || std::is_same_v<T, uint32_t>) {
+    return run_reduce_sync(value);
+  } else if constexpr (std::is_integral_v<T> && sizeof(T) <= 4) {
+    return static_cast<T>(run_reduce_sync(static_cast<int32_t>(value)));
+  }

@bucket-xv

bucket-xv commented Feb 3, 2026

As far as I know, __reduce_max_sync does not accept float types. Could you use inline assembly to emit redux.sync for floating types, available since sm_100a?

It's indeed complex. It seems that __reduce_max_sync can substitute for the ASM in the integer case, but no intrinsic wraps redux.sync for floating types on sm_100a.
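
For completeness, a rough sketch of what such an asm wrapper might look like, assuming the redux.sync.max.f32 form described in the PTX docs linked above for sm_100a (an unverified assumption; the exact operand constraints and qualifiers should be checked against the ISA reference):

__device__ __forceinline__ float warp_reduce_max_f32(unsigned mask, float v) {
  float r;  // redux.sync with .f32 operands is only documented for sm_100a
  asm volatile("redux.sync.max.f32 %0, %1, %2;" : "=f"(r) : "f"(v), "r"(mask));
  return r;
}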
