Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
141 changes: 112 additions & 29 deletions flashmask/flash_mask/cute/flash_bwd_sm100.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ def __init__(
deterministic: bool = False,
cluster_size: int = 1,
use_2cta_instrs: bool = False,
is_split_d: bool = False,
):
# padding head_dim to a multiple of 64 to match head_dim_rounded in interface
hdim_multiple_of = 64
Expand All @@ -76,28 +77,51 @@ def __init__(
self.tile_m = tile_m
self.tile_n = tile_n
self.debug_print = False
self.is_split_d = is_split_d

assert self.tile_hdim <= 128 or (self.tile_hdim == 192 and self.tile_hdimv == 128)
assert self.tile_hdimv <= 128
if is_split_d:
assert self.tile_hdim > 192 and self.tile_hdim == self.tile_hdimv, (
"Split-D BWD requires head_dim > 192 and head_dim == head_dim_v"
)
# Split-D sub-GEMM dimension (each 256-dim op split into two 128-dim ops)
self.tile_hdim_split = self.tile_hdim // 2 # 128
else:
assert self.tile_hdim <= 128 or (self.tile_hdim == 192 and self.tile_hdimv == 128)
assert self.tile_hdimv <= 128

self.use_2cta_instrs = bool(use_2cta_instrs and cluster_size == 2)
self.cta_group_size = 2 if self.use_2cta_instrs else 1

assert self.tile_hdim != 192 or self.use_2cta_instrs, "Must use 2CTA for hdim 192"
if is_split_d:
assert self.use_2cta_instrs, "Split-D BWD requires 2CTA"

# CTA tiler
self.cta_tiler = (tile_n, tile_m, self.tile_hdim)
# S = K @ Q.T
self.mma_tiler_kq = (self.cta_group_size * tile_n, tile_m, self.tile_hdim)
# dP = V @ dO.T
self.mma_tiler_vdo = (self.cta_group_size * tile_n, tile_m, self.tile_hdimv)
# dV = P.T @ dO
self.mma_tiler_pdo = (self.cta_group_size * tile_n, self.tile_hdimv, tile_m)
# dK = dS.T @ Q
self.mma_tiler_dsq = (self.cta_group_size * tile_n, self.tile_hdim, tile_m)
# dQ = dS @ K
# 2-CTA: reduction dim is cluster-wide (tile_n * cta_group_size).
self.mma_tiler_dsk = (tile_m, self.tile_hdim, tile_n * self.cta_group_size)
if self.is_split_d:
# Split-D: each sub-GEMM operates on tile_hdim_split=128
# S_lo = K_lo @ Q_lo.T, S_hi = K_hi @ Q_hi.T (accumulated)
self.mma_tiler_kq = (self.cta_group_size * tile_n, tile_m, self.tile_hdim_split)
# dP_lo = V_lo @ dO_lo.T, dP_hi = V_hi @ dO_hi.T (accumulated)
self.mma_tiler_vdo = (self.cta_group_size * tile_n, tile_m, self.tile_hdim_split)
# dV_lo = P.T @ dO_lo, dV_hi = P.T @ dO_hi
self.mma_tiler_pdo = (self.cta_group_size * tile_n, self.tile_hdim_split, tile_m)
# dK_lo = dS.T @ Q_lo, dK_hi = dS.T @ Q_hi
self.mma_tiler_dsq = (self.cta_group_size * tile_n, self.tile_hdim_split, tile_m)
# dQ_lo = dS @ K_lo, dQ_hi = dS @ K_hi
self.mma_tiler_dsk = (tile_m, self.tile_hdim_split, tile_n * self.cta_group_size)
else:
# S = K @ Q.T
self.mma_tiler_kq = (self.cta_group_size * tile_n, tile_m, self.tile_hdim)
# dP = V @ dO.T
self.mma_tiler_vdo = (self.cta_group_size * tile_n, tile_m, self.tile_hdimv)
# dV = P.T @ dO
self.mma_tiler_pdo = (self.cta_group_size * tile_n, self.tile_hdimv, tile_m)
# dK = dS.T @ Q
self.mma_tiler_dsq = (self.cta_group_size * tile_n, self.tile_hdim, tile_m)
# dQ = dS @ K
# 2-CTA: reduction dim is cluster-wide (tile_n * cta_group_size).
self.mma_tiler_dsk = (tile_m, self.tile_hdim, tile_n * self.cta_group_size)

self.acc_dtype = Float32
self.startend_row_indices_dtype = Int32
Expand Down Expand Up @@ -163,7 +187,26 @@ def __init__(
# self.tmem_total = self.tmem_S_offset + self.tile_n
# assert self.tmem_total <= self.tmem_alloc_cols

if self.use_2cta_instrs and self.tile_hdim == 192 and self.tile_hdimv == 128:
if self.is_split_d:
# Split-D (d=dv=256) TMEM layout:
# |--- S/P ---|--- dV_low ---|--- dV_high ---|--- dP/dS ---|
# [0, 128) [128, 256) [256, 384) [384, 512)
# TMEM[0,128) is time-division multiplexed: S/P, dK_partial, dQ_partial
assert self.tile_m == 128
assert self.tile_n == 128
self.tmem_S_offset = 0
self.tmem_P_offset = 0 # overlap with S
self.tmem_dV_lo_offset = 128
self.tmem_dV_hi_offset = 256
self.tmem_dV_offset = self.tmem_dV_lo_offset # for compatibility
self.tmem_dP_offset = 384
self.tmem_dS_offset = 384 # overlap with dP
# dK uses GMEM atomic reduce (no persistent TMEM accumulation)
self.tmem_dK_offset = 0 # time-multiplexed with S/P
# dQ partial accumulation at S/P slot
self.tmem_dQ_offset = 0 # time-multiplexed with S/P
self.dK_as_reduce = True # dK accumulated via GMEM atomic reduce
elif self.use_2cta_instrs and self.tile_hdim == 192 and self.tile_hdimv == 128:
assert self.tile_m == 128
assert self.tile_n == 128
self.tmem_dV_offset = 0
Expand All @@ -186,6 +229,9 @@ def __init__(
self.tmem_dK_offset = self.tmem_dP_offset + self.tile_m
self.tmem_dS_offset = self.tmem_dP_offset # overlap with dP

if not self.is_split_d:
self.dK_as_reduce = False

if (not is_causal and not is_local) or deterministic:
self.num_regs_reduce = 136 if self.use_2cta_instrs else 152
self.num_regs_compute = 136
Expand All @@ -204,6 +250,13 @@ def __init__(
self.num_regs_load = 128 - 24
self.num_regs_mma = self.num_regs_load

if const_expr(self.is_split_d):
# Split-D (d=256): similar to d=192 register pressure
self.num_regs_reduce = 136
self.num_regs_compute = 136
self.num_regs_load = 104
self.num_regs_mma = 104

assert (
self.num_regs_reduce
+ self.num_regs_compute * 2
Expand All @@ -213,6 +266,13 @@ def __init__(

self.buffer_align_bytes = 1024

# Split-D: offset (in elements, not bytes) from the low to the high half in SMEM buffers
if self.is_split_d:
# For K-major SMEM (e.g., sQ: tile_m × tile_hdim), high half starts at
# tile_hdim_split * element_size bytes from the start of each row
# Stored here; used in the MMA loop to compute sub-GEMM offsets
self.split_d_smem_k_offset_elems = self.tile_hdim_split # 128 elements

def _setup_attributes(self):
self.Q_stage = 1 if self.use_2cta_instrs else 2
self.dO_stage = 1
Expand All @@ -221,7 +281,12 @@ def _setup_attributes(self):
self.sdKVaccum_stage = 2
# number of tma reduce adds per dQacc mma
# todo: try 32/1 or 48/2 for 2cta d=192 dv=128
if self.use_2cta_instrs and self.tile_hdim == 192:
if self.is_split_d:
# Split-D: dQ reduce operates on tile_hdim_split=128 (each half independently)
self.dQ_reduce_ncol_t2r = 32
self.dQ_reduce_ncol = 16 if self.deterministic else 8
self.sdQaccum_stage = 2 if self.deterministic else 4
elif self.use_2cta_instrs and self.tile_hdim == 192:
self.dQ_reduce_ncol_t2r = 32
self.dQ_reduce_ncol = 24 if not self.is_causal else 32
self.sdQaccum_stage = 2 if not self.is_causal else 1
Expand All @@ -234,9 +299,11 @@ def _setup_attributes(self):
self.dQ_reduce_ncol = 32
self.sdQaccum_stage = 64 // self.dQ_reduce_ncol
self.dQ_reduce_ncol_t2r = self.dQ_reduce_ncol
assert (self.tile_hdim // self.cta_group_size) % self.dQ_reduce_ncol == 0
self.dQaccum_reduce_stage = self.tile_hdim // self.dQ_reduce_ncol
self.dQaccum_reduce_stage_t2r = self.tile_hdim // self.dQ_reduce_ncol_t2r
# For Split-D, reduce operations work on half dimensions
tile_hdim_for_reduce = self.tile_hdim_split if self.is_split_d else self.tile_hdim
assert (tile_hdim_for_reduce // self.cta_group_size) % self.dQ_reduce_ncol == 0
self.dQaccum_reduce_stage = tile_hdim_for_reduce // self.dQ_reduce_ncol
self.dQaccum_reduce_stage_t2r = tile_hdim_for_reduce // self.dQ_reduce_ncol_t2r
self.cluster_reduce_dQ = False and cute.size(self.cluster_shape_mn) > 1
# number of tma reduce adds for dKacc and dVacc epilogue
self.dK_reduce_ncol = 32
Expand Down Expand Up @@ -298,80 +365,96 @@ def _get_tiled_mma(self):
return tiled_mma_S, tiled_mma_dP, tiled_mma_dK, tiled_mma_dV, tiled_mma_dQ

def _setup_smem_layout(self):
# For Split-D: SMEM must hold full 256-dim data, but MMA operates on 128-dim halves.
# We define "full" tilers for SMEM layout computation and use the split tilers for MMA.
if self.is_split_d:
# Full tilers for SMEM allocation (use the full tile_hdim / tile_hdimv instead of tile_hdim_split)
mma_tiler_kq_full = (self.cta_group_size * self.tile_n, self.tile_m, self.tile_hdim)
mma_tiler_vdo_full = (self.cta_group_size * self.tile_n, self.tile_m, self.tile_hdimv)
mma_tiler_pdo_full = (self.cta_group_size * self.tile_n, self.tile_hdimv, self.tile_m)
mma_tiler_dsq_full = (self.cta_group_size * self.tile_n, self.tile_hdim, self.tile_m)
mma_tiler_dsk_full = (self.tile_m, self.tile_hdim, self.tile_n * self.cta_group_size)
else:
mma_tiler_kq_full = self.mma_tiler_kq
mma_tiler_vdo_full = self.mma_tiler_vdo
mma_tiler_pdo_full = self.mma_tiler_pdo
mma_tiler_dsq_full = self.mma_tiler_dsq
mma_tiler_dsk_full = self.mma_tiler_dsk

# S.T = K @ Q.T
sK_layout = sm100_utils_basic.make_smem_layout_a(
self.tiled_mma_S,
self.mma_tiler_kq,
mma_tiler_kq_full,
self.k_dtype,
1,
)
self.sK_layout = cute.slice_(sK_layout, (None, None, None, 0))
self.sQ_layout = sm100_utils_basic.make_smem_layout_b(
self.tiled_mma_S,
self.mma_tiler_kq,
mma_tiler_kq_full,
self.q_dtype,
self.Q_stage,
)
# dP.T = V @ dO.T
sV_layout = sm100_utils_basic.make_smem_layout_a(
self.tiled_mma_dP,
self.mma_tiler_vdo,
mma_tiler_vdo_full,
self.v_dtype,
1,
)
self.sV_layout = cute.slice_(sV_layout, (None, None, None, 0))
self.sdOt_layout = sm100_utils_basic.make_smem_layout_b(
self.tiled_mma_dP,
self.mma_tiler_vdo,
mma_tiler_vdo_full,
self.do_dtype,
self.dO_stage,
)
# dV += P.T @ dO
tP_layout = sm100_utils_basic.make_smem_layout_a(
self.tiled_mma_dV,
self.mma_tiler_pdo,
mma_tiler_pdo_full,
self.do_dtype,
1,
)
self.tP_layout = cute.slice_(tP_layout, (None, None, None, 0))
self.sdO_layout = sm100_utils_basic.make_smem_layout_b(
self.tiled_mma_dV,
self.mma_tiler_pdo,
mma_tiler_pdo_full,
self.do_dtype,
self.dO_stage,
)
# dK += dS.T @ Q
sdSt_layout = sm100_utils_basic.make_smem_layout_a(
self.tiled_mma_dK,
self.mma_tiler_dsq,
mma_tiler_dsq_full,
self.ds_dtype,
1,
)
self.sdSt_layout = cute.slice_(sdSt_layout, (None, None, None, 0))
tdS_layout = sm100_utils_basic.make_smem_layout_a(
self.tiled_mma_dK,
self.mma_tiler_dsq,
mma_tiler_dsq_full,
self.ds_dtype,
1,
)
self.tdS_layout = cute.slice_(tdS_layout, (None, None, None, 0))
self.sQt_layout = sm100_utils_basic.make_smem_layout_b(
self.tiled_mma_dK,
self.mma_tiler_dsq,
mma_tiler_dsq_full,
self.q_dtype,
self.Q_stage,
)
# dQ = dS @ K
sdS_layout = sm100_utils_basic.make_smem_layout_a(
self.tiled_mma_dQ,
self.mma_tiler_dsk,
mma_tiler_dsk_full,
self.ds_dtype,
1,
)
self.sdS_layout = cute.slice_(sdS_layout, (None, None, None, 0))
sKt_layout = sm100_utils_basic.make_smem_layout_b(
self.tiled_mma_dQ,
self.mma_tiler_dsk,
mma_tiler_dsk_full,
self.k_dtype,
1,
)
Expand Down
Loading