Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion csrc/flashmask_v2/flash_fwd_kernel_sm90.h
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ class FlashAttnFwdSm90 {
__shared__ int32_t flashmask_smem_[4 * kBlockN * CollectiveMainloop::kStages];
__shared__ __align__(128) int32_t flashmask_maxmin_smem[num_sch_stage * 8 * CollectiveMainloop::Flashmask_n_block_buffer_length * CollectiveMainloop::kNBlockStages];
__shared__ int32_t n_block_smem[num_sch_stage * CollectiveMainloop::Flashmask_n_block_buffer_length * CollectiveMainloop::kNBlockStages];
__shared__ __align__(128) int32_t blockmask_smem_[CollectiveMainloop::Blockmask_n_block_buffer_valid_length * CollectiveMainloop::kNBlockStages];
__shared__ __align__(128) int32_t blockmask_smem_[num_sch_stage * CollectiveMainloop::Blockmask_n_block_buffer_valid_length * CollectiveMainloop::kNBlockStages];
// When n_block_smem is full, we need to store the flag in the following extra flag storage, instead of allocating 4 more elements
__shared__ int32_t extra_flags[4]; // if num_sch_stage is 1, we actually only need two (kNBlockStages = 2)

Expand Down Expand Up @@ -329,6 +329,7 @@ class FlashAttnFwdSm90 {
if (valid_chunk)
pipeline_n_block.producer_acquire(n_block_pipe_write);
if (Is_blockmask) {
// if(m_block == 251 && threadIdx.x == 33) print("n_block_pipe_write.index : %d, cppl_stage: %d", n_block_pipe_write.index(), cppl_stage);
mainloop.load_blockmask(params.mainloop, seqlen_info, block_coord, reverse_chunk_idx, num_chunk,
blockmask_smem_ + CollectiveMainloop::Blockmask_n_block_buffer_valid_length * (n_block_pipe_write.index() + cppl_stage));
}
Expand Down
6 changes: 3 additions & 3 deletions csrc/flashmask_v2/mainloop_fwd_sm90_tma_gmma_ws.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -787,17 +787,17 @@ struct CollectiveMainloopFwdSm90 {
int32_t m_block = get<0>(block_coord);
const int thread_idx = threadIdx.x - 32;

const int chunks_size = total_num_chunks * Flashmask_n_block_buffer_length;
const int chunks_size = total_num_chunks * Blockmask_n_block_buffer_valid_length;
const int offset = (bidb * params.h_flashmask + bidh / params.h_h_flashmask_ratio) * chunks_size +
(chunks_size - (reverse_chunk_idx + 1) * Flashmask_n_block_buffer_length);
(chunks_size - (reverse_chunk_idx + 1) * Blockmask_n_block_buffer_valid_length);

const int nblock_seqlen = ((seqlen_info.seqlen_k + kBlockN - 1) / kBlockN + 3) & 0xfffffffc;

const int valid_block_nblock_seqlen = (seqlen_info.seqlen_k + n_block_dim - 1) / n_block_dim ; //xhy :maybe nblock_seqlen - 4
const int valid_block_mblock_seqlen = (seqlen_info.seqlen_q + m_block_dim - 1) / m_block_dim;
int blockmask_offset = (bidb * params.h_flashmask + bidh / params.h_h_flashmask_ratio) * valid_block_nblock_seqlen * valid_block_mblock_seqlen; // row_offset
blockmask_offset += m_block * valid_block_nblock_seqlen / m_factor;
blockmask_offset += std::max((valid_block_nblock_seqlen - (reverse_chunk_idx + 1) * Blockmask_n_block_buffer_valid_length), 0);
blockmask_offset += std::max((chunks_size - (reverse_chunk_idx + 1) * Blockmask_n_block_buffer_valid_length), 0);
int blockmask_length = Blockmask_n_block_buffer_valid_length < valid_block_nblock_seqlen ? Blockmask_n_block_buffer_valid_length : valid_block_nblock_seqlen;

//xhy: blockmask ptr maybe not 16-aligned, since load_blockmask is called before load_max_min, sync can be shared with load_max_min
Expand Down
Loading