Description
Required prerequisites
- I have read the documentation https://tilelang.com.
- I have searched the Issue Tracker and confirmed that this hasn't already been reported. (Comment there if it has.)
What version of TileLang are you using?
0.1.7.post1
System information
3.12.12 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 20:16:04) [GCC 11.2.0] linux
0.1.7.post1
2.9.0+cu128
:128: RuntimeWarning: 'torch.utils.collect_env' found in sys.modules after import of package 'torch.utils', but prior to execution of 'torch.utils.collect_env'; this may result in unpredictable behaviour
Collecting environment information...
PyTorch version: 2.9.0+cu128
Is debug build: False
CUDA used to build PyTorch: 12.8
ROCM used to build PyTorch: N/A
OS: Ubuntu 22.04.4 LTS (x86_64)
GCC version: (Ubuntu 11.4.0-1ubuntu1~22.04) 11.4.0
Clang version: Could not collect
CMake version: version 3.22.1
Libc version: glibc-2.35
Python version: 3.12.12 | packaged by Anaconda, Inc. | (main, Oct 21 2025, 20:16:04) [GCC 11.2.0] (64-bit runtime)
Python platform: Linux-5.15.0-94-generic-x86_64-with-glibc2.35
Is CUDA available: True
CUDA runtime version: Could not collect
CUDA_MODULE_LOADING set to:
GPU models and configuration:
GPU 0: NVIDIA H100 80GB HBM3
GPU 1: NVIDIA H100 80GB HBM3
GPU 2: NVIDIA H100 80GB HBM3
GPU 3: NVIDIA H100 80GB HBM3
GPU 4: NVIDIA H100 80GB HBM3
GPU 5: NVIDIA H100 80GB HBM3
GPU 6: NVIDIA H100 80GB HBM3
GPU 7: NVIDIA H100 80GB HBM3
Nvidia driver version: 580.82.07
cuDNN version: Could not collect
Is XPU available: False
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Architecture: x86_64
CPU op-mode(s): 32-bit, 64-bit
Address sizes: 52 bits physical, 57 bits virtual
Byte Order: Little Endian
CPU(s): 96
On-line CPU(s) list: 0-95
Vendor ID: GenuineIntel
Model name: Intel(R) Xeon(R) Platinum 8463B
CPU family: 6
Model: 143
Thread(s) per core: 1
Core(s) per socket: 48
Socket(s): 2
Stepping: 8
Frequency boost: enabled
CPU max MHz: 2601.0000
CPU min MHz: 800.0000
BogoMIPS: 5200.00
Flags: fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush dts acpi mmx fxsr sse sse2 ss ht tm pbe syscall nx pdpe1gb rdtscp lm constant_tsc art arch_perfmon pebs bts rep_good nopl xtopology nonstop_tsc cpuid aperfmperf tsc_known_freq pni pclmulqdq dtes64 ds_cpl vmx smx est tm2 ssse3 sdbg fma cx16 xtpr pdcm pcid dca sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand lahf_lm abm 3dnowprefetch cpuid_fault epb cat_l3 cat_l2 cdp_l3 invpcid_single intel_ppin cdp_l2 ssbd mba ibrs ibpb stibp ibrs_enhanced tpr_shadow vnmi flexpriority ept vpid ept_ad fsgsbase tsc_adjust bmi1 avx2 smep bmi2 erms invpcid cqm rdt_a avx512f avx512dq rdseed adx smap avx512ifma clflushopt clwb intel_pt avx512cd sha_ni avx512bw avx512vl xsaveopt xsavec xgetbv1 xsaves cqm_llc cqm_occup_llc cqm_mbm_total cqm_mbm_local split_lock_detect avx_vnni avx512_bf16 wbnoinvd dtherm ida arat pln pts avx512vbmi umip pku ospke waitpkg avx512_vbmi2 gfni vaes vpclmulqdq avx512_vnni avx512_bitalg tme avx512_vpopcntdq la57 rdpid bus_lock_detect cldemote movdiri movdir64b enqcmd fsrm md_clear serialize tsxldtrk pconfig arch_lbr amx_bf16 avx512_fp16 amx_tile amx_int8 flush_l1d arch_capabilities
Virtualization: VT-x
L1d cache: 4.5 MiB (96 instances)
L1i cache: 3 MiB (96 instances)
L2 cache: 192 MiB (96 instances)
L3 cache: 210 MiB (2 instances)
NUMA node(s): 2
NUMA node0 CPU(s): 0-47
NUMA node1 CPU(s): 48-95
Vulnerability Gather data sampling: Not affected
Vulnerability Itlb multihit: Not affected
Vulnerability L1tf: Not affected
Vulnerability Mds: Not affected
Vulnerability Meltdown: Not affected
Vulnerability Mmio stale data: Not affected
Vulnerability Retbleed: Not affected
Vulnerability Spec rstack overflow: Not affected
Vulnerability Spec store bypass: Mitigation; Speculative Store Bypass disabled via prctl and seccomp
Vulnerability Spectre v1: Mitigation; usercopy/swapgs barriers and __user pointer sanitization
Vulnerability Spectre v2: Mitigation; Enhanced IBRS, IBPB conditional, RSB filling, PBRSB-eIBRS SW sequence
Vulnerability Srbds: Not affected
Vulnerability Tsx async abort: Not affected
Versions of relevant libraries:
[pip3] Could not collect
[conda] numpy 1.26.4 pypi_0 pypi
[conda] nvidia-cublas-cu12 12.8.4.1 pypi_0 pypi
[conda] nvidia-cuda-cupti-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cuda-nvrtc-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-cuda-runtime-cu12 12.8.90 pypi_0 pypi
[conda] nvidia-cudnn-cu12 9.10.2.21 pypi_0 pypi
[conda] nvidia-cufft-cu12 11.3.3.83 pypi_0 pypi
[conda] nvidia-curand-cu12 10.3.9.90 pypi_0 pypi
[conda] nvidia-cusolver-cu12 11.7.3.90 pypi_0 pypi
[conda] nvidia-cusparse-cu12 12.5.8.93 pypi_0 pypi
[conda] nvidia-cusparselt-cu12 0.7.1 pypi_0 pypi
[conda] nvidia-nccl-cu12 2.27.5 pypi_0 pypi
[conda] nvidia-nvjitlink-cu12 12.8.93 pypi_0 pypi
[conda] nvidia-nvtx-cu12 12.8.90 pypi_0 pypi
[conda] torch 2.9.0 pypi_0 pypi
[conda] torchvision 0.24.0 pypi_0 pypi
[conda] triton 3.5.0 pypi_0 pypi
Problem description
The two branches below implement the same logic. Each sums `Block_count` along dim 1 into `Row_Sums`, then sums `Row_Sums` into the scalar `Total_Sum`. The `else` branch omits `dim=1` and relies on the default `dim`, which should be the last axis. The branches differ only in where the `T.print` calls are placed, yet the printed results are wrong. The log under Traceback was produced with `--version1`.
if version1:
    T.reduce_sum(Block_count, Row_Sums, dim=1, clear=True)
    for i, j in T.Parallel(valid_rows, valid_cols):
        if i == 0 and j == 0:
            T.print(Row_Sums, "Row_Sums")
    T.reduce_sum(Row_Sums, Total_Sum, dim=0, clear=True)
    for i, j in T.Parallel(valid_rows, valid_cols):
        if i == 0 and j == 0:
            T.print(Total_Sum[0], "Total_Sum")
            T.print(valid_rows, "valid_rows")
            T.print(valid_cols, "valid_cols")
else:
    T.reduce_sum(Block_count, Row_Sums, clear=True)
    T.reduce_sum(Row_Sums, Total_Sum, dim=0, clear=True)
    for i, j in T.Parallel(valid_rows, valid_cols):
        if i == 0 and j == 0:
            T.print(Total_Sum[0], "Total_Sum")
            T.print(valid_rows, "valid_rows")
            T.print(valid_cols, "valid_cols")
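The following is a pure-Python model of what one `(bx, by)` iteration of `get_block_mask_kernel` is meant to compute. It makes the expected values concrete so they can be compared with the printed ones.

This is a sketch under one assumption: `Block_count[i, j]` is 1 exactly when `Mask[Major_Ind[row], Minor_Ind[col]]` is true, as in the non-transposed branch. Rows and columns outside `valid_rows` / `valid_cols` stay 0 after `T.clear`. `reference_tile_sums` is a hypothetical helper and is not part of TileLang.

```python
from typing import List, Sequence, Tuple


def reference_tile_sums(
    major_ind: Sequence[int],
    minor_ind: Sequence[int],
    mask: Sequence[Sequence[bool]],
    block_major: int,
    block_minor: int,
    bx: int,
    by: int,
) -> Tuple[List[int], int]:
    """Hypothetical model of one (bx, by) tile: build Block_count, then reduce twice."""
    row_start = bx * block_major
    col_start = by * block_minor
    # Mirror valid_rows / valid_cols: the last tile along each axis may be ragged.
    valid_rows = min(block_major, len(major_ind) - row_start)
    valid_cols = min(block_minor, len(minor_ind) - col_start)
    # Row_Sums after T.clear: rows beyond valid_rows keep the value 0.
    row_sums = [0] * block_major
    for i in range(valid_rows):
        maj = major_ind[row_start + i]
        # Block_count[i, j] = 1 if Mask[maj, min] else 0, summed along dim=1.
        row_sums[i] = sum(
            1 if mask[maj][minor_ind[col_start + j]] else 0 for j in range(valid_cols)
        )
    # Second reduce_sum along dim=0 gives the scalar Total_Sum.
    total_sum = sum(row_sums)
    return row_sums, total_sum


if __name__ == "__main__":
    block = 64
    # Same layout as arange(2).repeat_interleave(64) in the repro.
    major = [i // block for i in range(128)]
    minor = [j // block for j in range(128)]
    mask = [[True, True], [False, False]]
    for bx in range(2):
        for by in range(2):
            rows, total = reference_tile_sums(major, minor, mask, block, block, bx, by)
            print(bx, by, rows[0], total)
```

The repro's inputs are 128 tokens, 64×64 tiles, and the mask `[[True, True], [False, False]]` printed in the log. For these inputs the model gives `Row_Sums[i] == 64` and `Total_Sum == 4096` for both tiles of block 0, and zeros for block 1.

The `version1` log instead shows:
- `Row_Sums` values of 1 in the first iteration and 0 in the second, for both blocks.
- `Total_Sum == 64`.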
Reproducible example code
The Python snippets:
import torch
from torch.nn.attention.flex_attention import create_block_mask
import tilelang.language as T
import tilelang


@tilelang.jit()
def get_block_mask_kernel(block_major, block_minor, transpose_mask, version1=False, threads=128):
    @T.prim_func
    def kernel(
        Major_Ind: T.Tensor([T.dynamic("seq_len_major")], T.int32),
        Minor_Ind: T.Tensor([T.dynamic("seq_len_minor")], T.int32),
        Mask: T.Tensor([T.dynamic("M"), T.dynamic("N")], T.bool),
        Cnt_Partial: T.Tensor([1, 1, T.dynamic("n_major_blocks")], T.int32),
        Idx_Partial: T.Tensor([1, 1, T.dynamic("n_major_blocks"), T.dynamic("n_minor_blocks")], T.int32),
        Cnt_Full: T.Tensor([1, 1, T.dynamic("n_major_blocks")], T.int32),
        Idx_Full: T.Tensor([1, 1, T.dynamic("n_major_blocks"), T.dynamic("n_minor_blocks")], T.int32),
    ):
        seq_len_major = Major_Ind.shape[0]
        seq_len_minor = Minor_Ind.shape[0]
        num_major_blocks = T.ceildiv(seq_len_major, block_major)
        num_minor_blocks = T.ceildiv(seq_len_minor, block_minor)

        # One CTA per major (q) block; loop over the minor (kv) blocks inside it.
        with T.Kernel(num_major_blocks, threads=threads) as bx:
            Major_Shared = T.alloc_shared([block_major], T.int32)
            Minor_Shared = T.alloc_shared([block_minor], T.int32)
            Block_count = T.alloc_shared([block_major, block_minor], T.int32)
            Row_Sums = T.alloc_shared([block_major], T.int32)
            Total_Sum = T.alloc_shared([1], T.int32)

            T.copy(Major_Ind[bx * block_major : (bx + 1) * block_major], Major_Shared)
            valid_rows = T.min(block_major, seq_len_major - bx * block_major)

            for by in T.Pipelined(num_minor_blocks):
                T.clear(Block_count)
                T.clear(Row_Sums)
                T.copy(Minor_Ind[by * block_minor : (by + 1) * block_minor], Minor_Shared)
                valid_cols = T.min(block_minor, seq_len_minor - by * block_minor)

                # Block_count[i, j] = 1 if element (i, j) of this tile is allowed by Mask.
                if transpose_mask:
                    for i, j in T.Parallel(valid_rows, valid_cols):
                        maj_idx = Major_Shared[i]
                        min_idx = Minor_Shared[j]
                        Block_count[i, j] = T.if_then_else(Mask[min_idx, maj_idx], 1, 0)
                else:
                    for i, j in T.Parallel(valid_rows, valid_cols):
                        maj_idx = Major_Shared[i]
                        min_idx = Minor_Shared[j]
                        Block_count[i, j] = T.if_then_else(Mask[maj_idx, min_idx], 1, 0)

                # Intended: Row_Sums[i] = sum_j Block_count[i, j]; Total_Sum[0] = sum_i Row_Sums[i].
                if version1:
                    T.reduce_sum(Block_count, Row_Sums, dim=1, clear=True)
                    for i, j in T.Parallel(valid_rows, valid_cols):
                        if i == 0 and j == 0:
                            T.print(Row_Sums, "Row_Sums")
                    T.reduce_sum(Row_Sums, Total_Sum, dim=0, clear=True)
                    for i, j in T.Parallel(valid_rows, valid_cols):
                        if i == 0 and j == 0:
                            T.print(Total_Sum[0], "Total_Sum")
                            T.print(valid_rows, "valid_rows")
                            T.print(valid_cols, "valid_cols")
                else:
                    T.reduce_sum(Block_count, Row_Sums, clear=True)
                    T.reduce_sum(Row_Sums, Total_Sum, dim=0, clear=True)
                    for i, j in T.Parallel(valid_rows, valid_cols):
                        if i == 0 and j == 0:
                            T.print(Total_Sum[0], "Total_Sum")
                            T.print(valid_rows, "valid_rows")
                            T.print(valid_cols, "valid_cols")

    return kernel
def get_varlen_block_mask(q_block_indices, k_block_indices, logical_block_mask, BLOCK_M, BLOCK_N, version1=False):
    """
    q_block_indices: [seq_len_q]
    k_block_indices: [seq_len_kv]
    logical_block_mask: [logical_block_mask_m, logical_block_mask_n]
    BLOCK_M: int, q block size
    BLOCK_N: int, kv block size
    version1: bool, selects which reduce_sum/print variant the debug kernel uses
    return:
        if forward (the only layout built here, since transpose_mask=False):
            kv_num_blocks: [1, 1, num_hw_q_blocks], the number of non-full kv blocks (which must still be checked with mask_mod) for each hw q block
            kv_indices: [1, 1, num_hw_q_blocks, num_hw_kv_blocks], the indices of the non-full kv blocks for each hw q block; only the first kv_num_blocks[0, 0, bx] entries are valid, the rest are invalid
            full_kv_num_blocks: [1, 1, num_hw_q_blocks], the number of full kv blocks for each hw q block
            full_kv_indices: [1, 1, num_hw_q_blocks, num_hw_kv_blocks], the indices of the full kv blocks for each hw q block; only the first full_kv_num_blocks[0, 0, bx] entries are valid, the rest are invalid
        else (backward):
            q_num_blocks: [1, 1, num_hw_kv_blocks], the number of non-full q blocks (which must still be checked with mask_mod) for each hw kv block
            q_indices: [1, 1, num_hw_kv_blocks, num_hw_q_blocks], the indices of the non-full q blocks for each hw kv block; only the first q_num_blocks[0, 0, bx] entries are valid, the rest are invalid
            full_q_num_blocks: [1, 1, num_hw_kv_blocks], the number of full q blocks for each hw kv block
            full_q_indices: [1, 1, num_hw_kv_blocks, num_hw_q_blocks], the indices of the full q blocks for each hw kv block; only the first full_q_num_blocks[0, 0, bx] entries are valid, the rest are invalid
    """
    device = q_block_indices.device
    seq_len_q = q_block_indices.shape[0]
    seq_len_kv = k_block_indices.shape[0]
    num_hw_q_blocks = (seq_len_q + BLOCK_M - 1) // BLOCK_M
    num_hw_kv_blocks = (seq_len_kv + BLOCK_N - 1) // BLOCK_N
    kv_num_blocks = torch.zeros(1, 1, num_hw_q_blocks, dtype=torch.int32, device=device)
    kv_indices = torch.zeros(1, 1, num_hw_q_blocks, num_hw_kv_blocks, dtype=torch.int32, device=device)
    full_kv_num_blocks = torch.zeros(1, 1, num_hw_q_blocks, dtype=torch.int32, device=device)
    full_kv_indices = torch.zeros(1, 1, num_hw_q_blocks, num_hw_kv_blocks, dtype=torch.int32, device=device)
    kernel = get_block_mask_kernel(BLOCK_M, BLOCK_N, transpose_mask=False, version1=version1)
    kernel(q_block_indices, k_block_indices, logical_block_mask, kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices)
    return kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices
def test_block_mask_forward(version1=False):
    torch.manual_seed(42)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    N_CTX_Q = 128
    N_CTX_KV = 128
    # Block sizes for Forward: M=Q, N=KV
    block_M = 64
    block_N = 64
    num_q_blocks = (N_CTX_Q + block_M - 1) // block_M
    num_kv_blocks = (N_CTX_KV + block_N - 1) // block_N
    q_block_indices = torch.arange(num_q_blocks, dtype=torch.int32, device=device).repeat_interleave(block_M)[:N_CTX_Q]
    k_block_indices = torch.arange(num_kv_blocks, dtype=torch.int32, device=device).repeat_interleave(block_N)[:N_CTX_KV]
    logical_block_mask = torch.randint(0, 2, (num_q_blocks, num_kv_blocks), dtype=torch.bool, device=device)
    print(logical_block_mask)

    # Reference using create_block_mask
    def block_mask_mod(b, h, q_idx, kv_idx):
        bqi = q_block_indices[q_idx]
        bki = k_block_indices[kv_idx]
        return logical_block_mask[bqi, bki]

    # For create_block_mask: BLOCK_SIZE=(Q_SIZE, KV_SIZE)
    _block_mask = create_block_mask(
        block_mask_mod,
        None,
        None,
        Q_LEN=N_CTX_Q,
        KV_LEN=N_CTX_KV,
        device=device,
        BLOCK_SIZE=(block_M, block_N),
        _compile=False,
    )

    # Run our implementation
    # forward=True: input BLOCK_M=Q_SIZE, BLOCK_N=KV_SIZE
    kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices = get_varlen_block_mask(
        q_block_indices,
        k_block_indices,
        logical_block_mask,
        BLOCK_M=block_M,
        BLOCK_N=block_N,
        version1=version1,
    )
    print(f"{_block_mask.kv_num_blocks.shape=}")
    print(f"{kv_num_blocks.shape=}")
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--version1", action="store_true")
    args = parser.parse_args()
    print(f"{args.version1=}")
    test_block_mask_forward(args.version1)
Traceback
Output of running the script with `--version1`:
python varlen-block-attention-dev/tests/test_block_mask_bug.py --version1
args.version1=True
tensor([[ True, True],
[False, False]], device='cuda:0')
_block_mask.kv_num_blocks.shape=torch.Size([1, 1, 2])
kv_num_blocks.shape=torch.Size([1, 1, 2])
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=0, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=0, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=1, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=1, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=2, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=2, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=3, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=3, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=4, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=4, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=5, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=5, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=6, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=6, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=7, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=7, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=8, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=8, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=9, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=9, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=10, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=10, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=11, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=11, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=12, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=12, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=13, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=13, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=14, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=14, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=15, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=15, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=16, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=16, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=17, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=17, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=18, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=18, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=19, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=19, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=20, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=20, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=21, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=21, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=22, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=22, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=23, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=23, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=24, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=24, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=25, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=25, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=26, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=26, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=27, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=27, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=28, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=28, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=29, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=29, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=30, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=30, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=31, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=31, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=32, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=32, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=33, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=33, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=34, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=34, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=35, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=35, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=36, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=36, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=37, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=37, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=38, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=38, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=39, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=39, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=40, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=40, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=41, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=41, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=42, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=42, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=43, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=43, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=44, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=44, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=45, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=45, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=46, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=46, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=47, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=47, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=48, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=48, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=49, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=49, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=50, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=50, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=51, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=51, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=52, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=52, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=53, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=53, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=54, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=54, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=55, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=55, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=56, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=56, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=57, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=57, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=58, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=58, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=59, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=59, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=60, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=60, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=61, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=61, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=62, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=62, dtype=int value=1
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=63, dtype=int value=1
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=63, dtype=int value=1
msg='Total_Sum' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='Total_Sum' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='valid_rows' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='valid_rows' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='valid_cols' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='valid_cols' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=0, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=0, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=1, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=1, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=2, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=2, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=3, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=3, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=4, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=4, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=5, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=5, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=6, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=6, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=7, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=7, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=8, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=8, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=9, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=9, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=10, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=10, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=11, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=11, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=12, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=12, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=13, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=13, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=14, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=14, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=15, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=15, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=16, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=16, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=17, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=17, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=18, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=18, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=19, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=19, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=20, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=20, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=21, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=21, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=22, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=22, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=23, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=23, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=24, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=24, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=25, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=25, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=26, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=26, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=27, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=27, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=28, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=28, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=29, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=29, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=30, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=30, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=31, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=31, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=32, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=32, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=33, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=33, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=34, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=34, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=35, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=35, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=36, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=36, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=37, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=37, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=38, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=38, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=39, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=39, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=40, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=40, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=41, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=41, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=42, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=42, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=43, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=43, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=44, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=44, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=45, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=45, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=46, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=46, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=47, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=47, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=48, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=48, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=49, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=49, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=50, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=50, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=51, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=51, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=52, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=52, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=53, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=53, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=54, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=54, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=55, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=55, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=56, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=56, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=57, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=57, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=58, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=58, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=59, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=59, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=60, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=60, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=61, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=61, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=62, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=62, dtype=int value=0
msg='Row_Sums' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=63, dtype=int value=0
msg='Row_Sums' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): buffer=Row_Sums, index=63, dtype=int value=0
msg='Total_Sum' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=0
msg='Total_Sum' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=0
msg='valid_rows' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='valid_rows' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='valid_cols' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='valid_cols' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
Without version1:
python varlen-block-attention-dev/tests/test_block_mask_bug.py
args.version1=False
tensor([[ True, True],
[False, False]], device='cuda:0')
_block_mask.kv_num_blocks.shape=torch.Size([1, 1, 2])
kv_num_blocks.shape=torch.Size([1, 1, 2])
msg='Total_Sum' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=4096
msg='Total_Sum' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=0
msg='valid_rows' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='valid_rows' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='valid_cols' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='valid_cols' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='Total_Sum' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=384
msg='Total_Sum' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=0
msg='valid_rows' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='valid_rows' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='valid_cols' BlockIdx=(0, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
msg='valid_cols' BlockIdx=(1, 0, 0), ThreadIdx=(0, 0, 0): dtype=int value=64
Expected behavior
No response
Additional context
No response