From e03485dd93ebb196269b649c518256daed7a159c Mon Sep 17 00:00:00 2001 From: Marceli Fylcek Date: Mon, 20 Apr 2026 07:30:16 -0700 Subject: [PATCH 1/5] triton op Signed-off-by: Marceli Fylcek --- .../mamba/test_selective_scan_triton.py | 378 +++++++++++++ vllm/_custom_ops.py | 31 ++ .../layers/mamba/ops/selective_scan_triton.py | 510 ++++++++++++++++++ 3 files changed, 919 insertions(+) create mode 100644 tests/kernels/mamba/test_selective_scan_triton.py create mode 100644 vllm/model_executor/layers/mamba/ops/selective_scan_triton.py diff --git a/tests/kernels/mamba/test_selective_scan_triton.py b/tests/kernels/mamba/test_selective_scan_triton.py new file mode 100644 index 000000000000..69890c3635db --- /dev/null +++ b/tests/kernels/mamba/test_selective_scan_triton.py @@ -0,0 +1,378 @@ +# SPDX-License-Identifier: Apache-2.0 +"""Test script for the Triton selective scan implementation on XPU.""" + +import torch +import torch.nn.functional as F + +from vllm.model_executor.layers.mamba.ops.selective_scan_triton import ( + selective_scan_fwd_triton, +) + + +def selective_scan_ref( + u, delta, A, B, C, D=None, z=None, delta_bias=None, + delta_softplus=False, prev_state=None, +): + """Reference implementation (pure PyTorch, sequential scan).""" + dtype_in = u.dtype + u = u.float() + delta = delta.float() + if delta_bias is not None: + delta = delta + delta_bias[..., None].float() + if delta_softplus: + delta = F.softplus(delta) + batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] + B = B.float() + C = C.float() + x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state.float() + ys = [] + deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A)) + if B.dim() == 3: + deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u) + else: + from einops import repeat + B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) + deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u) + if C.dim() == 4: + from einops import repeat + C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) + last_state = None + for i in range(u.shape[2]): + x = deltaA[:, :, i] * x + deltaB_u[:, :, i] + if C.dim() == 3: + y = torch.einsum("bdn,bn->bd", x, C[:, :, i]) + else: + y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i]) + if i == u.shape[2] - 1: + last_state = x + ys.append(y) + y = torch.stack(ys, dim=2) + out = y if D is None else y + u * D[None, :, None] + if z is not None: + out = out * F.silu(z.float()) + out = out.to(dtype=dtype_in) + return out, last_state + + +def test_basic(device, itype, seqlen, has_z, has_D, has_delta_bias, + delta_softplus, varBC_groups, batch_size=1, dim=4, dstate=8): + """Test basic selective scan correctness.""" + torch.manual_seed(42) + wtype = torch.float32 + + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + + A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype) + B_shape = [batch_size, varBC_groups, dstate, seqlen] + B = torch.randn(B_shape, device=device, dtype=itype) + C_shape = [batch_size, varBC_groups, dstate, seqlen] + C = torch.randn(C_shape, device=device, dtype=itype) + D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None + z = (torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) + if has_z else None) + delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) + if has_delta_bias else None) + u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) + delta = 0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype) + ssm_states = torch.zeros(batch_size, dim, dstate, device=device, dtype=itype) + + # Reference + u_ref = u.clone() + delta_ref = delta.clone() + z_ref = z.clone() if z is not None else None + out_ref, state_ref = selective_scan_ref( + u_ref, delta_ref, A.clone(), B.clone(), C.clone(), + D=D.clone() if D is not None else None, + z=z_ref, + delta_bias=delta_bias, + delta_softplus=delta_softplus, + ) + + # Prepare inputs for our kernel (need contiguous, matching expected shapes) + u_test = u.clone() + delta_test = delta.clone() + z_test = z.clone() if z is not None else None + B_test = B.clone() + C_test = C.clone() + ssm_states_test = ssm_states.clone() + + selective_scan_fwd_triton( + u_test, delta_test, A, B_test, C_test, + D, z_test, delta_bias, + delta_softplus, + query_start_loc=None, + cache_indices=None, + has_initial_state=None, + ssm_states=ssm_states_test, + null_block_id=-1, + block_size=2048, + ) + + # Get output (z if has_z, else delta) + out_test = z_test if has_z else delta_test + + # Compare + if not torch.allclose(out_test, out_ref, rtol=rtol, atol=atol): + max_diff = (out_test - out_ref).abs().max().item() + print(f" FAIL: max_diff={max_diff:.6e} (rtol={rtol}, atol={atol})") + return False + + if not torch.allclose(ssm_states_test, state_ref.to(itype), rtol=rtol, atol=atol): + max_diff = (ssm_states_test - state_ref.to(itype)).abs().max().item() + print(f" FAIL (states): max_diff={max_diff:.6e}") + return False + + return True + + +def test_with_initial_state(device, itype, seqlen): + """Test with initial state (chunked scan simulation).""" + torch.manual_seed(42) + dim, dstate, batch_size = 4, 8, 1 + varBC_groups = 1 + + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + + A = -0.5 * torch.rand(dim, dstate, device=device, dtype=torch.float32) + B = torch.randn(batch_size, varBC_groups, dstate, seqlen, device=device, dtype=itype) + C = torch.randn(batch_size, varBC_groups, dstate, seqlen, device=device, dtype=itype) + D = torch.randn(dim, device=device, dtype=torch.float32) + z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) + delta_bias = 0.5 * torch.rand(dim, device=device, dtype=torch.float32) + u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) + delta = 0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype) + + # Reference: full scan + out_ref, state_ref = selective_scan_ref( + u.clone(), delta.clone(), A.clone(), B.clone(), C.clone(), + D=D.clone(), z=z.clone(), delta_bias=delta_bias, delta_softplus=True, + ) + + # Triton: two-chunk scan + mid = seqlen // 2 + # First chunk + ssm_states = torch.zeros(batch_size, dim, dstate, device=device, dtype=itype) + delta1 = delta[..., :mid].clone() + z1 = z[..., :mid].clone() + selective_scan_fwd_triton( + u[..., :mid].contiguous(), delta1, A, B[..., :mid].contiguous(), + C[..., :mid].contiguous(), D, z1, delta_bias, True, + None, None, None, ssm_states, -1, 2048, + ) + out1 = z1 + + # Second chunk with initial state + delta2 = delta[..., mid:].clone() + z2 = z[..., mid:].clone() + selective_scan_fwd_triton( + u[..., mid:].contiguous(), delta2, A, B[..., mid:].contiguous(), + C[..., mid:].contiguous(), D, z2, delta_bias, True, + None, None, + torch.ones(batch_size, device=device, dtype=torch.bool), + ssm_states, -1, 2048, + ) + out2 = z2 + + out_test = torch.cat([out1, out2], dim=-1) + + if not torch.allclose(out_test, out_ref, rtol=rtol, atol=atol): + max_diff = (out_test - out_ref).abs().max().item() + print(f" FAIL: max_diff={max_diff:.6e}") + return False + + if not torch.allclose(ssm_states, state_ref.to(itype), rtol=rtol, atol=atol): + max_diff = (ssm_states - state_ref.to(itype)).abs().max().item() + print(f" FAIL (states): max_diff={max_diff:.6e}") + return False + + return True + + +def test_varlen(device, itype, seqlens): + """Test varlen mode with query_start_loc.""" + torch.manual_seed(42) + dim, dstate = 4, 8 + n_groups = 1 + batch_size = len(seqlens) + total_len = sum(seqlens) + + rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) + if itype == torch.bfloat16: + rtol, atol = 3e-2, 5e-2 + + A = -0.5 * torch.rand(dim, dstate, device=device, dtype=torch.float32) + D = torch.randn(dim, device=device, dtype=torch.float32) + delta_bias = 0.5 * torch.rand(dim, device=device, dtype=torch.float32) + + # Varlen layout: u=(dim, total_len), B=(n_groups, dstate, total_len) + u = torch.randn(dim, total_len, device=device, dtype=itype) + delta = 0.5 * torch.rand(dim, total_len, device=device, dtype=itype) + z = torch.randn(dim, total_len, device=device, dtype=itype) + B = torch.randn(n_groups, dstate, total_len, device=device, dtype=itype) + C = torch.randn(n_groups, dstate, total_len, device=device, dtype=itype) + ssm_states = torch.zeros(batch_size, dim, dstate, device=device, dtype=itype) + + query_start_loc = torch.zeros(batch_size + 1, device=device, dtype=torch.int32) + for i, sl in enumerate(seqlens): + query_start_loc[i + 1] = query_start_loc[i] + sl + + # Reference: process each sequence separately in batch mode + out_ref_parts = [] + state_refs = [] + offset = 0 + for i, sl in enumerate(seqlens): + u_i = u[:, offset:offset + sl].unsqueeze(0) # (1, dim, sl) + delta_i = delta[:, offset:offset + sl].unsqueeze(0) + z_i = z[:, offset:offset + sl].unsqueeze(0) + B_i = B[:, :, offset:offset + sl].unsqueeze(0) # (1, n_groups, dstate, sl) + C_i = C[:, :, offset:offset + sl].unsqueeze(0) + out_i, state_i = selective_scan_ref( + u_i, delta_i, A.clone(), B_i, C_i, + D=D.clone(), z=z_i, delta_bias=delta_bias, delta_softplus=True, + ) + out_ref_parts.append(out_i.squeeze(0)) # (dim, sl) + state_refs.append(state_i.squeeze(0)) + offset += sl + out_ref = torch.cat(out_ref_parts, dim=-1) # (dim, total_len) + state_ref = torch.stack(state_refs, dim=0) # (batch, dim, dstate) + + # Triton varlen + delta_test = delta.clone() + z_test = z.clone() + ssm_states_test = ssm_states.clone() + selective_scan_fwd_triton( + u.clone(), delta_test, A, B.clone(), C.clone(), + D, z_test, delta_bias, True, + query_start_loc, None, None, ssm_states_test, -1, 2048, + ) + out_test = z_test # z is present + + if not torch.allclose(out_test, out_ref, rtol=rtol, atol=atol): + max_diff = (out_test - out_ref).abs().max().item() + print(f" FAIL: max_diff={max_diff:.6e}") + return False + + if not torch.allclose(ssm_states_test, state_ref.to(itype), rtol=rtol, atol=atol): + max_diff = (ssm_states_test - state_ref.to(itype)).abs().max().item() + print(f" FAIL (states): max_diff={max_diff:.6e}") + return False + + return True + + +def main(): + device = "xpu:0" + print(f"Testing on {device}") + print(f"Device: {torch.xpu.get_device_name(0)}") + print() + + total = 0 + passed = 0 + + # Test 1: Basic configurations + print("=== Basic tests ===") + for itype in [torch.float32, torch.bfloat16]: + for seqlen in [16, 128, 512]: + for has_z in [True, False]: + for has_D in [True, False]: + for varBC_groups in [1, 2]: + total += 1 + name = (f"basic itype={itype}, seqlen={seqlen}, " + f"z={has_z}, D={has_D}, groups={varBC_groups}") + try: + ok = test_basic( + device, itype, seqlen, + has_z=has_z, has_D=has_D, + has_delta_bias=True, + delta_softplus=True, + varBC_groups=varBC_groups, + ) + if ok: + passed += 1 + print(f" PASS: {name}") + else: + print(f" FAIL: {name}") + except Exception as e: + print(f" ERROR: {name}: {e}") + import traceback + traceback.print_exc() + + # Test 2: Chunked scan with initial state + print("\n=== Chunked scan (initial state) tests ===") + for itype in [torch.float32, torch.bfloat16]: + for seqlen in [64, 256]: + total += 1 + name = f"chunked itype={itype}, seqlen={seqlen}" + try: + ok = test_with_initial_state(device, itype, seqlen) + if ok: + passed += 1 + print(f" PASS: {name}") + else: + print(f" FAIL: {name}") + except Exception as e: + print(f" ERROR: {name}: {e}") + import traceback + traceback.print_exc() + + # Test 3: Varlen mode + print("\n=== Varlen tests ===") + for itype in [torch.float32, torch.bfloat16]: + for seqlens in [[32, 16], [64, 32, 16]]: + total += 1 + name = f"varlen itype={itype}, seqlens={seqlens}" + try: + ok = test_varlen(device, itype, seqlens) + if ok: + passed += 1 + print(f" PASS: {name}") + else: + print(f" FAIL: {name}") + except Exception as e: + print(f" ERROR: {name}: {e}") + import traceback + traceback.print_exc() + + # Test 4: Larger dimensions (closer to real model) + print("\n=== Large dimension tests ===") + for itype in [torch.bfloat16]: + for dim in [256, 1024]: + for seqlen in [128, 512]: + total += 1 + name = f"large dim={dim}, seqlen={seqlen}, itype={itype}" + try: + ok = test_basic( + device, itype, seqlen, + has_z=True, has_D=True, + has_delta_bias=True, + delta_softplus=True, + varBC_groups=1, + dim=dim, dstate=16, + ) + if ok: + passed += 1 + print(f" PASS: {name}") + else: + print(f" FAIL: {name}") + except Exception as e: + print(f" ERROR: {name}: {e}") + import traceback + traceback.print_exc() + + print(f"\n{'='*60}") + print(f"Results: {passed}/{total} tests passed") + if passed == total: + print("ALL TESTS PASSED!") + else: + print(f"FAILURES: {total - passed}") + + return passed == total + + +if __name__ == "__main__": + success = main() + exit(0 if success else 1) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index cb4bd28f2677..262a62aa38cb 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2184,6 +2184,37 @@ def selective_scan_fwd( cu_chunk_seqlen: torch.Tensor | None = None, last_chunk_indices: torch.Tensor | None = None, ): + from vllm.platforms import current_platform + + if current_platform.is_xpu(): + from vllm.model_executor.layers.mamba.ops.selective_scan_triton import ( + selective_scan_fwd_triton, + ) + + selective_scan_fwd_triton( + u, + delta, + A, + B, + C, + D_, + z_, + delta_bias_, + delta_softplus, + query_start_loc, + cache_indices, + has_initial_state, + ssm_states, + null_block_id, + block_size, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + initial_state_idx, + cu_chunk_seqlen, + last_chunk_indices, + ) + return + torch.ops._C.selective_scan_fwd( u, delta, diff --git a/vllm/model_executor/layers/mamba/ops/selective_scan_triton.py b/vllm/model_executor/layers/mamba/ops/selective_scan_triton.py new file mode 100644 index 000000000000..a13f7498e11d --- /dev/null +++ b/vllm/model_executor/layers/mamba/ops/selective_scan_triton.py @@ -0,0 +1,510 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +""" +Triton implementation of the Mamba selective scan forward pass. + +This provides a platform-portable alternative to the CUDA-only +selective_scan_fwd kernel. It supports both varlen and non-varlen modes, +with optional z-gating, D bias, delta bias, and delta softplus. + +The kernel uses a sequential scan approach: each program handles one +(batch, dim) pair and iterates over the sequence length, maintaining +the SSM state vector across positions. Parallelism comes from launching +batch * dim programs concurrently. +""" + +import torch + +from vllm.triton_utils import tl, triton + + +@triton.jit +def _softplus(x): + return tl.where(x <= 20.0, tl.math.log(tl.math.exp(x) + 1.0), x) + + +@triton.jit +def _selective_scan_fwd_kernel( + # Pointers to input tensors + u_ptr, + delta_ptr, + A_ptr, + B_ptr, + C_ptr, + D_ptr, + z_ptr, + delta_bias_ptr, + # Pointers to output tensors (out aliases delta, out_z aliases z) + out_ptr, + out_z_ptr, + # SSM states + ssm_states_ptr, + # Optional pointers + query_start_loc_ptr, + cache_indices_ptr, + has_initial_state_ptr, + # APC pointers + block_idx_first_ptr, + block_idx_last_ptr, + initial_state_idx_ptr, + cu_chunk_seqlen_ptr, + last_chunk_indices_ptr, + # Dimensions + batch: tl.int32, + dim: tl.int32, + seqlen: tl.int32, + dstate: tl.int32, + n_groups: tl.int32, + dim_ngroups_ratio: tl.int32, + # Strides for u (and out, since out = delta which has same layout) + u_batch_stride: tl.int64, + u_d_stride: tl.int64, + # Strides for delta + delta_batch_stride: tl.int64, + delta_d_stride: tl.int64, + # Strides for A + A_d_stride: tl.int64, + A_dstate_stride: tl.int64, + # Strides for B + B_batch_stride: tl.int64, + B_group_stride: tl.int64, + B_dstate_stride: tl.int64, + # Strides for C + C_batch_stride: tl.int64, + C_group_stride: tl.int64, + C_dstate_stride: tl.int64, + # Strides for z + z_batch_stride: tl.int64, + z_d_stride: tl.int64, + # Strides for out + out_batch_stride: tl.int64, + out_d_stride: tl.int64, + # Strides for out_z + out_z_batch_stride: tl.int64, + out_z_d_stride: tl.int64, + # Strides for ssm_states + ssm_batch_stride: tl.int64, + ssm_dim_stride: tl.int64, + ssm_dstate_stride: tl.int64, + # Cache strides + cache_indices_stride: tl.int64, + # Scalar params + null_block_id: tl.int64, + block_size: tl.int32, + # Compile-time constants + delta_softplus: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_DELTA_BIAS: tl.constexpr, + IS_VARLEN: tl.constexpr, + HAS_CACHE_INDICES: tl.constexpr, + CACHE_ENABLED: tl.constexpr, + BLOCK_DSTATE: tl.constexpr, +): + batch_idx = tl.program_id(0) + dim_idx = tl.program_id(1) + group_idx = dim_idx // dim_ngroups_ratio + + # Determine sequence boundaries + if IS_VARLEN: + seq_start = tl.load(query_start_loc_ptr + batch_idx).to(tl.int32) + seq_end = tl.load(query_start_loc_ptr + batch_idx + 1).to(tl.int32) + actual_seqlen = seq_end - seq_start + else: + seq_start = 0 + actual_seqlen = seqlen + + # Determine cache index for ssm_states + if CACHE_ENABLED: + init_state_idx = tl.load(initial_state_idx_ptr + batch_idx).to(tl.int32) + load_cache_slot = tl.load( + cache_indices_ptr + batch_idx * cache_indices_stride + init_state_idx + ).to(tl.int64) + if load_cache_slot == null_block_id: + return + elif HAS_CACHE_INDICES: + cache_index = tl.load(cache_indices_ptr + batch_idx).to(tl.int64) + if cache_index == null_block_id: + return + load_cache_slot = cache_index + else: + load_cache_slot = batch_idx.to(tl.int64) + + # Load D value + D_val = 0.0 + if HAS_D: + D_val = tl.load(D_ptr + dim_idx).to(tl.float32) + + # Load delta_bias value + delta_bias_val = 0.0 + if HAS_DELTA_BIAS: + delta_bias_val = tl.load(delta_bias_ptr + dim_idx).to(tl.float32) + + # Load A values for this dim - shape (dstate,) + dstate_offs = tl.arange(0, BLOCK_DSTATE) + dstate_mask = dstate_offs < dstate + A_vals = tl.load( + A_ptr + dim_idx * A_d_stride + dstate_offs * A_dstate_stride, + mask=dstate_mask, + other=0.0, + ).to(tl.float32) + + # Initialize state vector + state = tl.zeros((BLOCK_DSTATE,), dtype=tl.float32) + + # Load initial state if available + has_init = False + if has_initial_state_ptr is not None: + has_init = tl.load(has_initial_state_ptr + batch_idx) + if has_init: + state = tl.load( + ssm_states_ptr + + load_cache_slot * ssm_batch_stride + + dim_idx * ssm_dim_stride + + dstate_offs * ssm_dstate_stride, + mask=dstate_mask, + other=0.0, + ).to(tl.float32) + + # Compute base addresses for u and delta + if IS_VARLEN: + u_base = u_ptr + dim_idx * u_d_stride + seq_start * u_batch_stride + delta_base = ( + delta_ptr + dim_idx * delta_d_stride + seq_start * delta_batch_stride + ) + out_base = ( + out_ptr + dim_idx * out_d_stride + seq_start * out_batch_stride + ) + B_base = B_ptr + group_idx * B_group_stride + seq_start * B_batch_stride + C_base = C_ptr + group_idx * C_group_stride + seq_start * C_batch_stride + else: + u_base = u_ptr + batch_idx * u_batch_stride + dim_idx * u_d_stride + delta_base = ( + delta_ptr + batch_idx * delta_batch_stride + dim_idx * delta_d_stride + ) + out_base = ( + out_ptr + batch_idx * out_batch_stride + dim_idx * out_d_stride + ) + B_base = B_ptr + batch_idx * B_batch_stride + group_idx * B_group_stride + C_base = C_ptr + batch_idx * C_batch_stride + group_idx * C_group_stride + + if HAS_Z: + if IS_VARLEN: + z_base = z_ptr + dim_idx * z_d_stride + seq_start * z_batch_stride + out_z_base = ( + out_z_ptr + + dim_idx * out_z_d_stride + + seq_start * out_z_batch_stride + ) + else: + z_base = z_ptr + batch_idx * z_batch_stride + dim_idx * z_d_stride + out_z_base = ( + out_z_ptr + + batch_idx * out_z_batch_stride + + dim_idx * out_z_d_stride + ) + + # Determine chunk boundaries for APC mode + if CACHE_ENABLED: + last_chunk_idx = tl.load(last_chunk_indices_ptr + batch_idx).to(tl.int32) + if batch_idx == 0: + first_chunk_idx = 0 + else: + first_chunk_idx = ( + tl.load(last_chunk_indices_ptr + batch_idx - 1).to(tl.int32) + 1 + ) + n_chunks = last_chunk_idx - first_chunk_idx + 1 + first_chunk_tokens = ( + tl.load(cu_chunk_seqlen_ptr + first_chunk_idx + 1).to(tl.int32) + - tl.load(cu_chunk_seqlen_ptr + first_chunk_idx).to(tl.int32) + ) + block_idx_first = tl.load(block_idx_first_ptr + batch_idx).to(tl.int32) + chunk_start_offset = 0 + if n_chunks > 1 and first_chunk_tokens < block_size: + chunk_start_offset = block_size - first_chunk_tokens + current_position = block_idx_first * block_size + chunk_start_offset + else: + n_chunks = 1 + first_chunk_idx = 0 + + # Sequential scan over the sequence + tokens_processed = 0 + for chunk in range(0, n_chunks if CACHE_ENABLED else 1): + if CACHE_ENABLED: + chunk_tokens = ( + tl.load( + cu_chunk_seqlen_ptr + first_chunk_idx + chunk + 1 + ).to(tl.int32) + - tl.load( + cu_chunk_seqlen_ptr + first_chunk_idx + chunk + ).to(tl.int32) + ) + else: + chunk_tokens = actual_seqlen + + for local_pos in range(chunk_tokens): + pos = tokens_processed + local_pos + # Load u value + u_val = tl.load(u_base + pos).to(tl.float32) + + # Load delta value + delta_val = tl.load(delta_base + pos).to(tl.float32) + + # Apply delta bias + if HAS_DELTA_BIAS: + delta_val = delta_val + delta_bias_val + + # Apply softplus + if delta_softplus: + delta_val = _softplus(delta_val) + + delta_u = delta_val * u_val + + # Compute dA = exp(delta * A) for all dstate elements + dA = tl.exp(delta_val * A_vals) + + # Load B values for this position + B_vals = tl.load( + B_base + dstate_offs * B_dstate_stride + pos, + mask=dstate_mask, + other=0.0, + ).to(tl.float32) + + # Load C values for this position + C_vals = tl.load( + C_base + dstate_offs * C_dstate_stride + pos, + mask=dstate_mask, + other=0.0, + ).to(tl.float32) + + # Update state: state = dA * state + delta * u * B + state = dA * state + delta_u * B_vals + + # Compute output: out = sum(state * C) + D * u + out_val = tl.sum(state * C_vals, axis=0) + if HAS_D: + out_val = out_val + D_val * u_val + + # Store output + tl.store(out_base + pos, out_val.to(out_ptr.dtype.element_ty)) + + if HAS_Z: + z_val = tl.load(z_base + pos).to(tl.float32) + out_z_val = out_val * z_val / (1.0 + tl.exp(-z_val)) + tl.store( + out_z_base + pos, + out_z_val.to(out_z_ptr.dtype.element_ty), + ) + + tokens_processed += chunk_tokens + + # Store intermediate state for APC mode + if CACHE_ENABLED: + if chunk == n_chunks - 1: + store_slot = tl.load( + cache_indices_ptr + + batch_idx * cache_indices_stride + + tl.load(block_idx_last_ptr + batch_idx).to(tl.int32) + ).to(tl.int64) + else: + block_idx_done = ( + current_position + chunk_tokens - 1 + ) // block_size + store_slot = tl.load( + cache_indices_ptr + + batch_idx * cache_indices_stride + + block_idx_done + ).to(tl.int64) + + tl.store( + ssm_states_ptr + + store_slot * ssm_batch_stride + + dim_idx * ssm_dim_stride + + dstate_offs * ssm_dstate_stride, + state.to(ssm_states_ptr.dtype.element_ty), + mask=dstate_mask, + ) + current_position += chunk_tokens + + # Store final state for non-APC mode + if not CACHE_ENABLED: + tl.store( + ssm_states_ptr + + load_cache_slot * ssm_batch_stride + + dim_idx * ssm_dim_stride + + dstate_offs * ssm_dstate_stride, + state.to(ssm_states_ptr.dtype.element_ty), + mask=dstate_mask, + ) + + +def selective_scan_fwd_triton( + u: torch.Tensor, + delta: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + D_: torch.Tensor | None, + z_: torch.Tensor | None, + delta_bias_: torch.Tensor | None, + delta_softplus: bool, + query_start_loc: torch.Tensor | None, + cache_indices: torch.Tensor | None, + has_initial_state: torch.Tensor | None, + ssm_states: torch.Tensor, + null_block_id: int, + block_size: int = 1024, + block_idx_first_scheduled_token: torch.Tensor | None = None, + block_idx_last_scheduled_token: torch.Tensor | None = None, + initial_state_idx: torch.Tensor | None = None, + cu_chunk_seqlen: torch.Tensor | None = None, + last_chunk_indices: torch.Tensor | None = None, +): + """ + Triton implementation of selective scan forward pass. + + This writes output in-place to delta (when z is None) or z (when z is + provided), matching the CUDA kernel's behavior. + + See selective_scan_fn() in mamba_ssm.py for parameter documentation. + """ + varlen = query_start_loc is not None + batch_size = ( + (query_start_loc.shape[0] - 1) if varlen else u.shape[0] + ) + dim = u.shape[0] if varlen else u.shape[1] + total_seqlen = u.shape[1] if varlen else u.shape[2] + dstate = A.size(1) + n_groups = B.size(0) if varlen else B.size(1) + dim_ngroups_ratio = dim // n_groups + + has_z = z_ is not None + has_D = D_ is not None + has_delta_bias = delta_bias_ is not None + has_cache_indices = cache_indices is not None + cache_enabled = block_idx_first_scheduled_token is not None + + # out and out_z alias delta and z respectively + out = delta + out_z = z_ if has_z else delta # dummy, won't be used if not has_z + + BLOCK_DSTATE = triton.next_power_of_2(dstate) + + # Compute strides + if varlen: + u_batch_stride = u.stride(1) + u_d_stride = u.stride(0) + delta_batch_stride = delta.stride(1) + delta_d_stride = delta.stride(0) + B_batch_stride = B.stride(2) + B_group_stride = B.stride(0) + B_dstate_stride = B.stride(1) + C_batch_stride = C.stride(2) + C_group_stride = C.stride(0) + C_dstate_stride = C.stride(1) + out_batch_stride = out.stride(1) + out_d_stride = out.stride(0) + if has_z: + z_batch_stride = z_.stride(1) + z_d_stride = z_.stride(0) + out_z_batch_stride = out_z.stride(1) + out_z_d_stride = out_z.stride(0) + else: + z_batch_stride = 0 + z_d_stride = 0 + out_z_batch_stride = 0 + out_z_d_stride = 0 + else: + u_batch_stride = u.stride(0) + u_d_stride = u.stride(1) + delta_batch_stride = delta.stride(0) + delta_d_stride = delta.stride(1) + B_batch_stride = B.stride(0) + B_group_stride = B.stride(1) + B_dstate_stride = B.stride(2) + C_batch_stride = C.stride(0) + C_group_stride = C.stride(1) + C_dstate_stride = C.stride(2) + out_batch_stride = out.stride(0) + out_d_stride = out.stride(1) + if has_z: + z_batch_stride = z_.stride(0) + z_d_stride = z_.stride(1) + out_z_batch_stride = out_z.stride(0) + out_z_d_stride = out_z.stride(1) + else: + z_batch_stride = 0 + z_d_stride = 0 + out_z_batch_stride = 0 + out_z_d_stride = 0 + + ssm_batch_stride = ssm_states.stride(0) + ssm_dim_stride = ssm_states.stride(1) + ssm_dstate_stride = ssm_states.stride(2) + cache_indices_stride = cache_indices.stride(0) if has_cache_indices else 0 + + grid = (batch_size, dim) + _selective_scan_fwd_kernel[grid]( + u, + delta, + A, + B, + C, + D_ if has_D else u, # dummy, won't be dereferenced + z_ if has_z else u, # dummy + delta_bias_ if has_delta_bias else u, # dummy + out, + out_z, + ssm_states, + query_start_loc if varlen else u, # dummy + cache_indices if has_cache_indices else u, # dummy + has_initial_state, + # APC pointers + block_idx_first_scheduled_token if cache_enabled else u, + block_idx_last_scheduled_token if cache_enabled else u, + initial_state_idx if cache_enabled else u, + cu_chunk_seqlen if cache_enabled else u, + last_chunk_indices if cache_enabled else u, + # Dimensions + batch_size, + dim, + total_seqlen, + dstate, + n_groups, + dim_ngroups_ratio, + # Strides + u_batch_stride, + u_d_stride, + delta_batch_stride, + delta_d_stride, + A.stride(0), + A.stride(1), + B_batch_stride, + B_group_stride, + B_dstate_stride, + C_batch_stride, + C_group_stride, + C_dstate_stride, + z_batch_stride, + z_d_stride, + out_batch_stride, + out_d_stride, + out_z_batch_stride, + out_z_d_stride, + ssm_batch_stride, + ssm_dim_stride, + ssm_dstate_stride, + cache_indices_stride, + null_block_id, + block_size, + # Compile-time constants + delta_softplus=delta_softplus, + HAS_D=has_D, + HAS_Z=has_z, + HAS_DELTA_BIAS=has_delta_bias, + IS_VARLEN=varlen, + HAS_CACHE_INDICES=has_cache_indices, + CACHE_ENABLED=cache_enabled, + BLOCK_DSTATE=BLOCK_DSTATE, + ) From 8c063dd76eb5c774e63c0e0015f258ac330e392c Mon Sep 17 00:00:00 2001 From: Marceli Fylcek Date: Tue, 19 May 2026 18:24:16 +0300 Subject: [PATCH 2/5] Move to _xpu_ops Signed-off-by: Marceli Fylcek --- vllm/_custom_ops.py | 31 -- vllm/_xpu_ops.py | 486 ++++++++++++++++++ .../layers/mamba/ops/mamba_ssm.py | 72 ++- 3 files changed, 536 insertions(+), 53 deletions(-) diff --git a/vllm/_custom_ops.py b/vllm/_custom_ops.py index 262a62aa38cb..cb4bd28f2677 100644 --- a/vllm/_custom_ops.py +++ b/vllm/_custom_ops.py @@ -2184,37 +2184,6 @@ def selective_scan_fwd( cu_chunk_seqlen: torch.Tensor | None = None, last_chunk_indices: torch.Tensor | None = None, ): - from vllm.platforms import current_platform - - if current_platform.is_xpu(): - from vllm.model_executor.layers.mamba.ops.selective_scan_triton import ( - selective_scan_fwd_triton, - ) - - selective_scan_fwd_triton( - u, - delta, - A, - B, - C, - D_, - z_, - delta_bias_, - delta_softplus, - query_start_loc, - cache_indices, - has_initial_state, - ssm_states, - null_block_id, - block_size, - block_idx_first_scheduled_token, - block_idx_last_scheduled_token, - initial_state_idx, - cu_chunk_seqlen, - last_chunk_indices, - ) - return - torch.ops._C.selective_scan_fwd( u, delta, diff --git a/vllm/_xpu_ops.py b/vllm/_xpu_ops.py index 6a2e5e841ce1..01602e9e94d1 100644 --- a/vllm/_xpu_ops.py +++ b/vllm/_xpu_ops.py @@ -8,6 +8,7 @@ from vllm.logger import init_logger from vllm.platforms import current_platform +from vllm.triton_utils import tl, triton from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) @@ -296,6 +297,326 @@ def _xpu_mxfp4_quantize_fake( return x_q, x_s +@triton.jit +def _softplus(x): + return tl.where(x <= 20.0, tl.math.log(tl.math.exp(x) + 1.0), x) + + +@triton.jit +def _selective_scan_fwd_kernel( + # Pointers to input tensors + u_ptr, + delta_ptr, + A_ptr, + B_ptr, + C_ptr, + D_ptr, + z_ptr, + delta_bias_ptr, + # Pointers to output tensors (out aliases delta, out_z aliases z) + out_ptr, + out_z_ptr, + # SSM states + ssm_states_ptr, + # Optional pointers + query_start_loc_ptr, + cache_indices_ptr, + has_initial_state_ptr, + # APC pointers + block_idx_first_ptr, + block_idx_last_ptr, + initial_state_idx_ptr, + cu_chunk_seqlen_ptr, + last_chunk_indices_ptr, + # Dimensions + batch: tl.int32, + dim: tl.int32, + seqlen: tl.int32, + dstate: tl.int32, + n_groups: tl.int32, + dim_ngroups_ratio: tl.int32, + # Strides for u (and out, since out = delta which has same layout) + u_batch_stride: tl.int64, + u_d_stride: tl.int64, + # Strides for delta + delta_batch_stride: tl.int64, + delta_d_stride: tl.int64, + # Strides for A + A_d_stride: tl.int64, + A_dstate_stride: tl.int64, + # Strides for B + B_batch_stride: tl.int64, + B_group_stride: tl.int64, + B_dstate_stride: tl.int64, + # Strides for C + C_batch_stride: tl.int64, + C_group_stride: tl.int64, + C_dstate_stride: tl.int64, + # Strides for z + z_batch_stride: tl.int64, + z_d_stride: tl.int64, + # Strides for out + out_batch_stride: tl.int64, + out_d_stride: tl.int64, + # Strides for out_z + out_z_batch_stride: tl.int64, + out_z_d_stride: tl.int64, + # Strides for ssm_states + ssm_batch_stride: tl.int64, + ssm_dim_stride: tl.int64, + ssm_dstate_stride: tl.int64, + # Cache strides + cache_indices_stride: tl.int64, + # Scalar params + null_block_id: tl.int64, + block_size: tl.int32, + # Compile-time constants + delta_softplus: tl.constexpr, + HAS_D: tl.constexpr, + HAS_Z: tl.constexpr, + HAS_DELTA_BIAS: tl.constexpr, + IS_VARLEN: tl.constexpr, + HAS_CACHE_INDICES: tl.constexpr, + CACHE_ENABLED: tl.constexpr, + BLOCK_DSTATE: tl.constexpr, +): + batch_idx = tl.program_id(0) + dim_idx = tl.program_id(1) + group_idx = dim_idx // dim_ngroups_ratio + + # Determine sequence boundaries + if IS_VARLEN: + seq_start = tl.load(query_start_loc_ptr + batch_idx).to(tl.int32) + seq_end = tl.load(query_start_loc_ptr + batch_idx + 1).to(tl.int32) + actual_seqlen = seq_end - seq_start + else: + seq_start = 0 + actual_seqlen = seqlen + + # Determine cache index for ssm_states + if CACHE_ENABLED: + init_state_idx = tl.load(initial_state_idx_ptr + batch_idx).to(tl.int32) + load_cache_slot = tl.load( + cache_indices_ptr + batch_idx * cache_indices_stride + init_state_idx + ).to(tl.int64) + if load_cache_slot == null_block_id: + return + elif HAS_CACHE_INDICES: + cache_index = tl.load(cache_indices_ptr + batch_idx).to(tl.int64) + if cache_index == null_block_id: + return + load_cache_slot = cache_index + else: + load_cache_slot = batch_idx.to(tl.int64) + + # Load D value + D_val = 0.0 + if HAS_D: + D_val = tl.load(D_ptr + dim_idx).to(tl.float32) + + # Load delta_bias value + delta_bias_val = 0.0 + if HAS_DELTA_BIAS: + delta_bias_val = tl.load(delta_bias_ptr + dim_idx).to(tl.float32) + + # Load A values for this dim - shape (dstate,) + dstate_offs = tl.arange(0, BLOCK_DSTATE) + dstate_mask = dstate_offs < dstate + A_vals = tl.load( + A_ptr + dim_idx * A_d_stride + dstate_offs * A_dstate_stride, + mask=dstate_mask, + other=0.0, + ).to(tl.float32) + + # Initialize state vector + state = tl.zeros((BLOCK_DSTATE,), dtype=tl.float32) + + # Load initial state if available + has_init = False + if has_initial_state_ptr is not None: + has_init = tl.load(has_initial_state_ptr + batch_idx) + if has_init: + state = tl.load( + ssm_states_ptr + + load_cache_slot * ssm_batch_stride + + dim_idx * ssm_dim_stride + + dstate_offs * ssm_dstate_stride, + mask=dstate_mask, + other=0.0, + ).to(tl.float32) + + # Compute base addresses for u and delta + if IS_VARLEN: + u_base = u_ptr + dim_idx * u_d_stride + seq_start * u_batch_stride + delta_base = ( + delta_ptr + dim_idx * delta_d_stride + seq_start * delta_batch_stride + ) + out_base = ( + out_ptr + dim_idx * out_d_stride + seq_start * out_batch_stride + ) + B_base = B_ptr + group_idx * B_group_stride + seq_start * B_batch_stride + C_base = C_ptr + group_idx * C_group_stride + seq_start * C_batch_stride + else: + u_base = u_ptr + batch_idx * u_batch_stride + dim_idx * u_d_stride + delta_base = ( + delta_ptr + batch_idx * delta_batch_stride + dim_idx * delta_d_stride + ) + out_base = ( + out_ptr + batch_idx * out_batch_stride + dim_idx * out_d_stride + ) + B_base = B_ptr + batch_idx * B_batch_stride + group_idx * B_group_stride + C_base = C_ptr + batch_idx * C_batch_stride + group_idx * C_group_stride + + if HAS_Z: + if IS_VARLEN: + z_base = z_ptr + dim_idx * z_d_stride + seq_start * z_batch_stride + out_z_base = ( + out_z_ptr + + dim_idx * out_z_d_stride + + seq_start * out_z_batch_stride + ) + else: + z_base = z_ptr + batch_idx * z_batch_stride + dim_idx * z_d_stride + out_z_base = ( + out_z_ptr + + batch_idx * out_z_batch_stride + + dim_idx * out_z_d_stride + ) + + # Determine chunk boundaries for APC mode + if CACHE_ENABLED: + last_chunk_idx = tl.load(last_chunk_indices_ptr + batch_idx).to(tl.int32) + if batch_idx == 0: + first_chunk_idx = 0 + else: + first_chunk_idx = ( + tl.load(last_chunk_indices_ptr + batch_idx - 1).to(tl.int32) + 1 + ) + n_chunks = last_chunk_idx - first_chunk_idx + 1 + first_chunk_tokens = ( + tl.load(cu_chunk_seqlen_ptr + first_chunk_idx + 1).to(tl.int32) + - tl.load(cu_chunk_seqlen_ptr + first_chunk_idx).to(tl.int32) + ) + block_idx_first = tl.load(block_idx_first_ptr + batch_idx).to(tl.int32) + chunk_start_offset = 0 + if n_chunks > 1 and first_chunk_tokens < block_size: + chunk_start_offset = block_size - first_chunk_tokens + current_position = block_idx_first * block_size + chunk_start_offset + else: + n_chunks = 1 + first_chunk_idx = 0 + + # Sequential scan over the sequence + tokens_processed = 0 + for chunk in range(0, n_chunks if CACHE_ENABLED else 1): + if CACHE_ENABLED: + chunk_tokens = ( + tl.load( + cu_chunk_seqlen_ptr + first_chunk_idx + chunk + 1 + ).to(tl.int32) + - tl.load( + cu_chunk_seqlen_ptr + first_chunk_idx + chunk + ).to(tl.int32) + ) + else: + chunk_tokens = actual_seqlen + + for local_pos in range(chunk_tokens): + pos = tokens_processed + local_pos + # Load u value + u_val = tl.load(u_base + pos).to(tl.float32) + + # Load delta value + delta_val = tl.load(delta_base + pos).to(tl.float32) + + # Apply delta bias + if HAS_DELTA_BIAS: + delta_val = delta_val + delta_bias_val + + # Apply softplus + if delta_softplus: + delta_val = _softplus(delta_val) + + delta_u = delta_val * u_val + + # Compute dA = exp(delta * A) for all dstate elements + dA = tl.exp(delta_val * A_vals) + + # Load B values for this position + B_vals = tl.load( + B_base + dstate_offs * B_dstate_stride + pos, + mask=dstate_mask, + other=0.0, + ).to(tl.float32) + + # Load C values for this position + C_vals = tl.load( + C_base + dstate_offs * C_dstate_stride + pos, + mask=dstate_mask, + other=0.0, + ).to(tl.float32) + + # Update state: state = dA * state + delta * u * B + state = dA * state + delta_u * B_vals + + # Compute output: out = sum(state * C) + D * u + out_val = tl.sum(state * C_vals, axis=0) + if HAS_D: + out_val = out_val + D_val * u_val + + # Store output + tl.store(out_base + pos, out_val.to(out_ptr.dtype.element_ty)) + + if HAS_Z: + z_val = tl.load(z_base + pos).to(tl.float32) + out_z_val = out_val * z_val / (1.0 + tl.exp(-z_val)) + tl.store( + out_z_base + pos, + out_z_val.to(out_z_ptr.dtype.element_ty), + ) + + tokens_processed += chunk_tokens + + # Store intermediate state for APC mode + if CACHE_ENABLED: + if chunk == n_chunks - 1: + store_slot = tl.load( + cache_indices_ptr + + batch_idx * cache_indices_stride + + tl.load(block_idx_last_ptr + batch_idx).to(tl.int32) + ).to(tl.int64) + else: + block_idx_done = ( + current_position + chunk_tokens - 1 + ) // block_size + store_slot = tl.load( + cache_indices_ptr + + batch_idx * cache_indices_stride + + block_idx_done + ).to(tl.int64) + + tl.store( + ssm_states_ptr + + store_slot * ssm_batch_stride + + dim_idx * ssm_dim_stride + + dstate_offs * ssm_dstate_stride, + state.to(ssm_states_ptr.dtype.element_ty), + mask=dstate_mask, + ) + current_position += chunk_tokens + + # Store final state for non-APC mode + if not CACHE_ENABLED: + tl.store( + ssm_states_ptr + + load_cache_slot * ssm_batch_stride + + dim_idx * ssm_dim_stride + + dstate_offs * ssm_dstate_stride, + state.to(ssm_states_ptr.dtype.element_ty), + mask=dstate_mask, + ) + + # Global flag to ensure ops are registered only once _OPS_REGISTERED = False @@ -621,6 +942,171 @@ def cp_gather_indexer_k_quant_cache( ) dst_scale[:] = kv_cache_flat[scale_indices] + @staticmethod + def selective_scan_fwd( + u: torch.Tensor, + delta: torch.Tensor, + A: torch.Tensor, + B: torch.Tensor, + C: torch.Tensor, + D_: torch.Tensor | None, + z_: torch.Tensor | None, + delta_bias_: torch.Tensor | None, + delta_softplus: bool, + query_start_loc: torch.Tensor | None, + cache_indices: torch.Tensor | None, + has_initial_state: torch.Tensor | None, + ssm_states: torch.Tensor, + null_block_id: int, + block_size: int = 1024, + block_idx_first_scheduled_token: torch.Tensor | None = None, + block_idx_last_scheduled_token: torch.Tensor | None = None, + initial_state_idx: torch.Tensor | None = None, + cu_chunk_seqlen: torch.Tensor | None = None, + last_chunk_indices: torch.Tensor | None = None, + ) -> None: + varlen = query_start_loc is not None + batch_size = ( + (query_start_loc.shape[0] - 1) if varlen else u.shape[0] + ) + dim = u.shape[0] if varlen else u.shape[1] + total_seqlen = u.shape[1] if varlen else u.shape[2] + dstate = A.size(1) + n_groups = B.size(0) if varlen else B.size(1) + dim_ngroups_ratio = dim // n_groups + + has_z = z_ is not None + has_D = D_ is not None + has_delta_bias = delta_bias_ is not None + has_cache_indices = cache_indices is not None + cache_enabled = block_idx_first_scheduled_token is not None + + # out and out_z alias delta and z respectively + out = delta + out_z = z_ if has_z else delta # dummy, won't be used if not has_z + + BLOCK_DSTATE = triton.next_power_of_2(dstate) + + # Compute strides + if varlen: + u_batch_stride = u.stride(1) + u_d_stride = u.stride(0) + delta_batch_stride = delta.stride(1) + delta_d_stride = delta.stride(0) + B_batch_stride = B.stride(2) + B_group_stride = B.stride(0) + B_dstate_stride = B.stride(1) + C_batch_stride = C.stride(2) + C_group_stride = C.stride(0) + C_dstate_stride = C.stride(1) + out_batch_stride = out.stride(1) + out_d_stride = out.stride(0) + if has_z: + z_batch_stride = z_.stride(1) + z_d_stride = z_.stride(0) + out_z_batch_stride = out_z.stride(1) + out_z_d_stride = out_z.stride(0) + else: + z_batch_stride = 0 + z_d_stride = 0 + out_z_batch_stride = 0 + out_z_d_stride = 0 + else: + u_batch_stride = u.stride(0) + u_d_stride = u.stride(1) + delta_batch_stride = delta.stride(0) + delta_d_stride = delta.stride(1) + B_batch_stride = B.stride(0) + B_group_stride = B.stride(1) + B_dstate_stride = B.stride(2) + C_batch_stride = C.stride(0) + C_group_stride = C.stride(1) + C_dstate_stride = C.stride(2) + out_batch_stride = out.stride(0) + out_d_stride = out.stride(1) + if has_z: + z_batch_stride = z_.stride(0) + z_d_stride = z_.stride(1) + out_z_batch_stride = out_z.stride(0) + out_z_d_stride = out_z.stride(1) + else: + z_batch_stride = 0 + z_d_stride = 0 + out_z_batch_stride = 0 + out_z_d_stride = 0 + + ssm_batch_stride = ssm_states.stride(0) + ssm_dim_stride = ssm_states.stride(1) + ssm_dstate_stride = ssm_states.stride(2) + cache_indices_stride = ( + cache_indices.stride(0) if has_cache_indices else 0 + ) + + grid = (batch_size, dim) + _selective_scan_fwd_kernel[grid]( + u, + delta, + A, + B, + C, + D_ if has_D else u, # dummy, won't be dereferenced + z_ if has_z else u, # dummy + delta_bias_ if has_delta_bias else u, # dummy + out, + out_z, + ssm_states, + query_start_loc if varlen else u, # dummy + cache_indices if has_cache_indices else u, # dummy + has_initial_state, + # APC pointers + block_idx_first_scheduled_token if cache_enabled else u, + block_idx_last_scheduled_token if cache_enabled else u, + initial_state_idx if cache_enabled else u, + cu_chunk_seqlen if cache_enabled else u, + last_chunk_indices if cache_enabled else u, + # Dimensions + batch_size, + dim, + total_seqlen, + dstate, + n_groups, + dim_ngroups_ratio, + # Strides + u_batch_stride, + u_d_stride, + delta_batch_stride, + delta_d_stride, + A.stride(0), + A.stride(1), + B_batch_stride, + B_group_stride, + B_dstate_stride, + C_batch_stride, + C_group_stride, + C_dstate_stride, + z_batch_stride, + z_d_stride, + out_batch_stride, + out_d_stride, + out_z_batch_stride, + out_z_d_stride, + ssm_batch_stride, + ssm_dim_stride, + ssm_dstate_stride, + cache_indices_stride, + null_block_id, + block_size, + # Compile-time constants + delta_softplus=delta_softplus, + HAS_D=has_D, + HAS_Z=has_z, + HAS_DELTA_BIAS=has_delta_bias, + IS_VARLEN=varlen, + HAS_CACHE_INDICES=has_cache_indices, + CACHE_ENABLED=cache_enabled, + BLOCK_DSTATE=BLOCK_DSTATE, + ) + @staticmethod def top_k_per_row_prefill( logits: torch.Tensor, diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index e3c8ba8312f2..db8ac74a6ea0 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -9,9 +9,13 @@ from vllm import _custom_ops as ops from vllm.model_executor.layers.mamba.ops.triton_helpers import fast_exp +from vllm.platforms import current_platform from vllm.triton_utils import HAS_TRITON, tl, triton from vllm.v1.attention.backends.utils import NULL_BLOCK_ID +if current_platform.is_xpu(): + from vllm._xpu_ops import xpu_ops + TRITON3 = HAS_TRITON and (version.parse(triton.__version__) >= version.parse("3.0.0")) if TRITON3: @@ -629,28 +633,52 @@ def selective_scan_fn( if C.dim() == 2 and query_start_loc is not None: C = C.unsqueeze(0) - ops.selective_scan_fwd( - u, - delta, - A, - B, - C, - D, - z, - delta_bias, - delta_softplus, - query_start_loc, - cache_indices, - has_initial_state, - ssm_states, - null_block_id, - block_size, - block_idx_first_scheduled_token, - block_idx_last_scheduled_token, - initial_state_idx, - cu_chunk_seqlen, - last_chunk_indices, - ) + if current_platform.is_xpu(): + xpu_ops.selective_scan_fwd( + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + query_start_loc, + cache_indices, + has_initial_state, + ssm_states, + null_block_id, + block_size, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + initial_state_idx, + cu_chunk_seqlen, + last_chunk_indices, + ) + else: + ops.selective_scan_fwd( + u, + delta, + A, + B, + C, + D, + z, + delta_bias, + delta_softplus, + query_start_loc, + cache_indices, + has_initial_state, + ssm_states, + null_block_id, + block_size, + block_idx_first_scheduled_token, + block_idx_last_scheduled_token, + initial_state_idx, + cu_chunk_seqlen, + last_chunk_indices, + ) if z is None: return delta # output written inplace to delta From a1f7ca70d9419960365ab61df1c6483daf03ed01 Mon Sep 17 00:00:00 2001 From: Marceli Fylcek Date: Thu, 21 May 2026 15:17:32 +0300 Subject: [PATCH 3/5] Remove old Signed-off-by: Marceli Fylcek --- .../mamba/test_selective_scan_triton.py | 378 ------------- .../layers/mamba/ops/selective_scan_triton.py | 510 ------------------ 2 files changed, 888 deletions(-) delete mode 100644 tests/kernels/mamba/test_selective_scan_triton.py delete mode 100644 vllm/model_executor/layers/mamba/ops/selective_scan_triton.py diff --git a/tests/kernels/mamba/test_selective_scan_triton.py b/tests/kernels/mamba/test_selective_scan_triton.py deleted file mode 100644 index 69890c3635db..000000000000 --- a/tests/kernels/mamba/test_selective_scan_triton.py +++ /dev/null @@ -1,378 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -"""Test script for the Triton selective scan implementation on XPU.""" - -import torch -import torch.nn.functional as F - -from vllm.model_executor.layers.mamba.ops.selective_scan_triton import ( - selective_scan_fwd_triton, -) - - -def selective_scan_ref( - u, delta, A, B, C, D=None, z=None, delta_bias=None, - delta_softplus=False, prev_state=None, -): - """Reference implementation (pure PyTorch, sequential scan).""" - dtype_in = u.dtype - u = u.float() - delta = delta.float() - if delta_bias is not None: - delta = delta + delta_bias[..., None].float() - if delta_softplus: - delta = F.softplus(delta) - batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1] - B = B.float() - C = C.float() - x = A.new_zeros((batch, dim, dstate)) if prev_state is None else prev_state.float() - ys = [] - deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A)) - if B.dim() == 3: - deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u) - else: - from einops import repeat - B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1]) - deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u) - if C.dim() == 4: - from einops import repeat - C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1]) - last_state = None - for i in range(u.shape[2]): - x = deltaA[:, :, i] * x + deltaB_u[:, :, i] - if C.dim() == 3: - y = torch.einsum("bdn,bn->bd", x, C[:, :, i]) - else: - y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i]) - if i == u.shape[2] - 1: - last_state = x - ys.append(y) - y = torch.stack(ys, dim=2) - out = y if D is None else y + u * D[None, :, None] - if z is not None: - out = out * F.silu(z.float()) - out = out.to(dtype=dtype_in) - return out, last_state - - -def test_basic(device, itype, seqlen, has_z, has_D, has_delta_bias, - delta_softplus, varBC_groups, batch_size=1, dim=4, dstate=8): - """Test basic selective scan correctness.""" - torch.manual_seed(42) - wtype = torch.float32 - - rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol = 3e-2, 5e-2 - - A = -0.5 * torch.rand(dim, dstate, device=device, dtype=wtype) - B_shape = [batch_size, varBC_groups, dstate, seqlen] - B = torch.randn(B_shape, device=device, dtype=itype) - C_shape = [batch_size, varBC_groups, dstate, seqlen] - C = torch.randn(C_shape, device=device, dtype=itype) - D = torch.randn(dim, device=device, dtype=torch.float32) if has_D else None - z = (torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) - if has_z else None) - delta_bias = (0.5 * torch.rand(dim, device=device, dtype=torch.float32) - if has_delta_bias else None) - u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) - delta = 0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype) - ssm_states = torch.zeros(batch_size, dim, dstate, device=device, dtype=itype) - - # Reference - u_ref = u.clone() - delta_ref = delta.clone() - z_ref = z.clone() if z is not None else None - out_ref, state_ref = selective_scan_ref( - u_ref, delta_ref, A.clone(), B.clone(), C.clone(), - D=D.clone() if D is not None else None, - z=z_ref, - delta_bias=delta_bias, - delta_softplus=delta_softplus, - ) - - # Prepare inputs for our kernel (need contiguous, matching expected shapes) - u_test = u.clone() - delta_test = delta.clone() - z_test = z.clone() if z is not None else None - B_test = B.clone() - C_test = C.clone() - ssm_states_test = ssm_states.clone() - - selective_scan_fwd_triton( - u_test, delta_test, A, B_test, C_test, - D, z_test, delta_bias, - delta_softplus, - query_start_loc=None, - cache_indices=None, - has_initial_state=None, - ssm_states=ssm_states_test, - null_block_id=-1, - block_size=2048, - ) - - # Get output (z if has_z, else delta) - out_test = z_test if has_z else delta_test - - # Compare - if not torch.allclose(out_test, out_ref, rtol=rtol, atol=atol): - max_diff = (out_test - out_ref).abs().max().item() - print(f" FAIL: max_diff={max_diff:.6e} (rtol={rtol}, atol={atol})") - return False - - if not torch.allclose(ssm_states_test, state_ref.to(itype), rtol=rtol, atol=atol): - max_diff = (ssm_states_test - state_ref.to(itype)).abs().max().item() - print(f" FAIL (states): max_diff={max_diff:.6e}") - return False - - return True - - -def test_with_initial_state(device, itype, seqlen): - """Test with initial state (chunked scan simulation).""" - torch.manual_seed(42) - dim, dstate, batch_size = 4, 8, 1 - varBC_groups = 1 - - rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol = 3e-2, 5e-2 - - A = -0.5 * torch.rand(dim, dstate, device=device, dtype=torch.float32) - B = torch.randn(batch_size, varBC_groups, dstate, seqlen, device=device, dtype=itype) - C = torch.randn(batch_size, varBC_groups, dstate, seqlen, device=device, dtype=itype) - D = torch.randn(dim, device=device, dtype=torch.float32) - z = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) - delta_bias = 0.5 * torch.rand(dim, device=device, dtype=torch.float32) - u = torch.randn(batch_size, dim, seqlen, device=device, dtype=itype) - delta = 0.5 * torch.rand(batch_size, dim, seqlen, device=device, dtype=itype) - - # Reference: full scan - out_ref, state_ref = selective_scan_ref( - u.clone(), delta.clone(), A.clone(), B.clone(), C.clone(), - D=D.clone(), z=z.clone(), delta_bias=delta_bias, delta_softplus=True, - ) - - # Triton: two-chunk scan - mid = seqlen // 2 - # First chunk - ssm_states = torch.zeros(batch_size, dim, dstate, device=device, dtype=itype) - delta1 = delta[..., :mid].clone() - z1 = z[..., :mid].clone() - selective_scan_fwd_triton( - u[..., :mid].contiguous(), delta1, A, B[..., :mid].contiguous(), - C[..., :mid].contiguous(), D, z1, delta_bias, True, - None, None, None, ssm_states, -1, 2048, - ) - out1 = z1 - - # Second chunk with initial state - delta2 = delta[..., mid:].clone() - z2 = z[..., mid:].clone() - selective_scan_fwd_triton( - u[..., mid:].contiguous(), delta2, A, B[..., mid:].contiguous(), - C[..., mid:].contiguous(), D, z2, delta_bias, True, - None, None, - torch.ones(batch_size, device=device, dtype=torch.bool), - ssm_states, -1, 2048, - ) - out2 = z2 - - out_test = torch.cat([out1, out2], dim=-1) - - if not torch.allclose(out_test, out_ref, rtol=rtol, atol=atol): - max_diff = (out_test - out_ref).abs().max().item() - print(f" FAIL: max_diff={max_diff:.6e}") - return False - - if not torch.allclose(ssm_states, state_ref.to(itype), rtol=rtol, atol=atol): - max_diff = (ssm_states - state_ref.to(itype)).abs().max().item() - print(f" FAIL (states): max_diff={max_diff:.6e}") - return False - - return True - - -def test_varlen(device, itype, seqlens): - """Test varlen mode with query_start_loc.""" - torch.manual_seed(42) - dim, dstate = 4, 8 - n_groups = 1 - batch_size = len(seqlens) - total_len = sum(seqlens) - - rtol, atol = (6e-4, 2e-3) if itype == torch.float32 else (3e-3, 5e-3) - if itype == torch.bfloat16: - rtol, atol = 3e-2, 5e-2 - - A = -0.5 * torch.rand(dim, dstate, device=device, dtype=torch.float32) - D = torch.randn(dim, device=device, dtype=torch.float32) - delta_bias = 0.5 * torch.rand(dim, device=device, dtype=torch.float32) - - # Varlen layout: u=(dim, total_len), B=(n_groups, dstate, total_len) - u = torch.randn(dim, total_len, device=device, dtype=itype) - delta = 0.5 * torch.rand(dim, total_len, device=device, dtype=itype) - z = torch.randn(dim, total_len, device=device, dtype=itype) - B = torch.randn(n_groups, dstate, total_len, device=device, dtype=itype) - C = torch.randn(n_groups, dstate, total_len, device=device, dtype=itype) - ssm_states = torch.zeros(batch_size, dim, dstate, device=device, dtype=itype) - - query_start_loc = torch.zeros(batch_size + 1, device=device, dtype=torch.int32) - for i, sl in enumerate(seqlens): - query_start_loc[i + 1] = query_start_loc[i] + sl - - # Reference: process each sequence separately in batch mode - out_ref_parts = [] - state_refs = [] - offset = 0 - for i, sl in enumerate(seqlens): - u_i = u[:, offset:offset + sl].unsqueeze(0) # (1, dim, sl) - delta_i = delta[:, offset:offset + sl].unsqueeze(0) - z_i = z[:, offset:offset + sl].unsqueeze(0) - B_i = B[:, :, offset:offset + sl].unsqueeze(0) # (1, n_groups, dstate, sl) - C_i = C[:, :, offset:offset + sl].unsqueeze(0) - out_i, state_i = selective_scan_ref( - u_i, delta_i, A.clone(), B_i, C_i, - D=D.clone(), z=z_i, delta_bias=delta_bias, delta_softplus=True, - ) - out_ref_parts.append(out_i.squeeze(0)) # (dim, sl) - state_refs.append(state_i.squeeze(0)) - offset += sl - out_ref = torch.cat(out_ref_parts, dim=-1) # (dim, total_len) - state_ref = torch.stack(state_refs, dim=0) # (batch, dim, dstate) - - # Triton varlen - delta_test = delta.clone() - z_test = z.clone() - ssm_states_test = ssm_states.clone() - selective_scan_fwd_triton( - u.clone(), delta_test, A, B.clone(), C.clone(), - D, z_test, delta_bias, True, - query_start_loc, None, None, ssm_states_test, -1, 2048, - ) - out_test = z_test # z is present - - if not torch.allclose(out_test, out_ref, rtol=rtol, atol=atol): - max_diff = (out_test - out_ref).abs().max().item() - print(f" FAIL: max_diff={max_diff:.6e}") - return False - - if not torch.allclose(ssm_states_test, state_ref.to(itype), rtol=rtol, atol=atol): - max_diff = (ssm_states_test - state_ref.to(itype)).abs().max().item() - print(f" FAIL (states): max_diff={max_diff:.6e}") - return False - - return True - - -def main(): - device = "xpu:0" - print(f"Testing on {device}") - print(f"Device: {torch.xpu.get_device_name(0)}") - print() - - total = 0 - passed = 0 - - # Test 1: Basic configurations - print("=== Basic tests ===") - for itype in [torch.float32, torch.bfloat16]: - for seqlen in [16, 128, 512]: - for has_z in [True, False]: - for has_D in [True, False]: - for varBC_groups in [1, 2]: - total += 1 - name = (f"basic itype={itype}, seqlen={seqlen}, " - f"z={has_z}, D={has_D}, groups={varBC_groups}") - try: - ok = test_basic( - device, itype, seqlen, - has_z=has_z, has_D=has_D, - has_delta_bias=True, - delta_softplus=True, - varBC_groups=varBC_groups, - ) - if ok: - passed += 1 - print(f" PASS: {name}") - else: - print(f" FAIL: {name}") - except Exception as e: - print(f" ERROR: {name}: {e}") - import traceback - traceback.print_exc() - - # Test 2: Chunked scan with initial state - print("\n=== Chunked scan (initial state) tests ===") - for itype in [torch.float32, torch.bfloat16]: - for seqlen in [64, 256]: - total += 1 - name = f"chunked itype={itype}, seqlen={seqlen}" - try: - ok = test_with_initial_state(device, itype, seqlen) - if ok: - passed += 1 - print(f" PASS: {name}") - else: - print(f" FAIL: {name}") - except Exception as e: - print(f" ERROR: {name}: {e}") - import traceback - traceback.print_exc() - - # Test 3: Varlen mode - print("\n=== Varlen tests ===") - for itype in [torch.float32, torch.bfloat16]: - for seqlens in [[32, 16], [64, 32, 16]]: - total += 1 - name = f"varlen itype={itype}, seqlens={seqlens}" - try: - ok = test_varlen(device, itype, seqlens) - if ok: - passed += 1 - print(f" PASS: {name}") - else: - print(f" FAIL: {name}") - except Exception as e: - print(f" ERROR: {name}: {e}") - import traceback - traceback.print_exc() - - # Test 4: Larger dimensions (closer to real model) - print("\n=== Large dimension tests ===") - for itype in [torch.bfloat16]: - for dim in [256, 1024]: - for seqlen in [128, 512]: - total += 1 - name = f"large dim={dim}, seqlen={seqlen}, itype={itype}" - try: - ok = test_basic( - device, itype, seqlen, - has_z=True, has_D=True, - has_delta_bias=True, - delta_softplus=True, - varBC_groups=1, - dim=dim, dstate=16, - ) - if ok: - passed += 1 - print(f" PASS: {name}") - else: - print(f" FAIL: {name}") - except Exception as e: - print(f" ERROR: {name}: {e}") - import traceback - traceback.print_exc() - - print(f"\n{'='*60}") - print(f"Results: {passed}/{total} tests passed") - if passed == total: - print("ALL TESTS PASSED!") - else: - print(f"FAILURES: {total - passed}") - - return passed == total - - -if __name__ == "__main__": - success = main() - exit(0 if success else 1) diff --git a/vllm/model_executor/layers/mamba/ops/selective_scan_triton.py b/vllm/model_executor/layers/mamba/ops/selective_scan_triton.py deleted file mode 100644 index a13f7498e11d..000000000000 --- a/vllm/model_executor/layers/mamba/ops/selective_scan_triton.py +++ /dev/null @@ -1,510 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -""" -Triton implementation of the Mamba selective scan forward pass. - -This provides a platform-portable alternative to the CUDA-only -selective_scan_fwd kernel. It supports both varlen and non-varlen modes, -with optional z-gating, D bias, delta bias, and delta softplus. - -The kernel uses a sequential scan approach: each program handles one -(batch, dim) pair and iterates over the sequence length, maintaining -the SSM state vector across positions. Parallelism comes from launching -batch * dim programs concurrently. -""" - -import torch - -from vllm.triton_utils import tl, triton - - -@triton.jit -def _softplus(x): - return tl.where(x <= 20.0, tl.math.log(tl.math.exp(x) + 1.0), x) - - -@triton.jit -def _selective_scan_fwd_kernel( - # Pointers to input tensors - u_ptr, - delta_ptr, - A_ptr, - B_ptr, - C_ptr, - D_ptr, - z_ptr, - delta_bias_ptr, - # Pointers to output tensors (out aliases delta, out_z aliases z) - out_ptr, - out_z_ptr, - # SSM states - ssm_states_ptr, - # Optional pointers - query_start_loc_ptr, - cache_indices_ptr, - has_initial_state_ptr, - # APC pointers - block_idx_first_ptr, - block_idx_last_ptr, - initial_state_idx_ptr, - cu_chunk_seqlen_ptr, - last_chunk_indices_ptr, - # Dimensions - batch: tl.int32, - dim: tl.int32, - seqlen: tl.int32, - dstate: tl.int32, - n_groups: tl.int32, - dim_ngroups_ratio: tl.int32, - # Strides for u (and out, since out = delta which has same layout) - u_batch_stride: tl.int64, - u_d_stride: tl.int64, - # Strides for delta - delta_batch_stride: tl.int64, - delta_d_stride: tl.int64, - # Strides for A - A_d_stride: tl.int64, - A_dstate_stride: tl.int64, - # Strides for B - B_batch_stride: tl.int64, - B_group_stride: tl.int64, - B_dstate_stride: tl.int64, - # Strides for C - C_batch_stride: tl.int64, - C_group_stride: tl.int64, - C_dstate_stride: tl.int64, - # Strides for z - z_batch_stride: tl.int64, - z_d_stride: tl.int64, - # Strides for out - out_batch_stride: tl.int64, - out_d_stride: tl.int64, - # Strides for out_z - out_z_batch_stride: tl.int64, - out_z_d_stride: tl.int64, - # Strides for ssm_states - ssm_batch_stride: tl.int64, - ssm_dim_stride: tl.int64, - ssm_dstate_stride: tl.int64, - # Cache strides - cache_indices_stride: tl.int64, - # Scalar params - null_block_id: tl.int64, - block_size: tl.int32, - # Compile-time constants - delta_softplus: tl.constexpr, - HAS_D: tl.constexpr, - HAS_Z: tl.constexpr, - HAS_DELTA_BIAS: tl.constexpr, - IS_VARLEN: tl.constexpr, - HAS_CACHE_INDICES: tl.constexpr, - CACHE_ENABLED: tl.constexpr, - BLOCK_DSTATE: tl.constexpr, -): - batch_idx = tl.program_id(0) - dim_idx = tl.program_id(1) - group_idx = dim_idx // dim_ngroups_ratio - - # Determine sequence boundaries - if IS_VARLEN: - seq_start = tl.load(query_start_loc_ptr + batch_idx).to(tl.int32) - seq_end = tl.load(query_start_loc_ptr + batch_idx + 1).to(tl.int32) - actual_seqlen = seq_end - seq_start - else: - seq_start = 0 - actual_seqlen = seqlen - - # Determine cache index for ssm_states - if CACHE_ENABLED: - init_state_idx = tl.load(initial_state_idx_ptr + batch_idx).to(tl.int32) - load_cache_slot = tl.load( - cache_indices_ptr + batch_idx * cache_indices_stride + init_state_idx - ).to(tl.int64) - if load_cache_slot == null_block_id: - return - elif HAS_CACHE_INDICES: - cache_index = tl.load(cache_indices_ptr + batch_idx).to(tl.int64) - if cache_index == null_block_id: - return - load_cache_slot = cache_index - else: - load_cache_slot = batch_idx.to(tl.int64) - - # Load D value - D_val = 0.0 - if HAS_D: - D_val = tl.load(D_ptr + dim_idx).to(tl.float32) - - # Load delta_bias value - delta_bias_val = 0.0 - if HAS_DELTA_BIAS: - delta_bias_val = tl.load(delta_bias_ptr + dim_idx).to(tl.float32) - - # Load A values for this dim - shape (dstate,) - dstate_offs = tl.arange(0, BLOCK_DSTATE) - dstate_mask = dstate_offs < dstate - A_vals = tl.load( - A_ptr + dim_idx * A_d_stride + dstate_offs * A_dstate_stride, - mask=dstate_mask, - other=0.0, - ).to(tl.float32) - - # Initialize state vector - state = tl.zeros((BLOCK_DSTATE,), dtype=tl.float32) - - # Load initial state if available - has_init = False - if has_initial_state_ptr is not None: - has_init = tl.load(has_initial_state_ptr + batch_idx) - if has_init: - state = tl.load( - ssm_states_ptr - + load_cache_slot * ssm_batch_stride - + dim_idx * ssm_dim_stride - + dstate_offs * ssm_dstate_stride, - mask=dstate_mask, - other=0.0, - ).to(tl.float32) - - # Compute base addresses for u and delta - if IS_VARLEN: - u_base = u_ptr + dim_idx * u_d_stride + seq_start * u_batch_stride - delta_base = ( - delta_ptr + dim_idx * delta_d_stride + seq_start * delta_batch_stride - ) - out_base = ( - out_ptr + dim_idx * out_d_stride + seq_start * out_batch_stride - ) - B_base = B_ptr + group_idx * B_group_stride + seq_start * B_batch_stride - C_base = C_ptr + group_idx * C_group_stride + seq_start * C_batch_stride - else: - u_base = u_ptr + batch_idx * u_batch_stride + dim_idx * u_d_stride - delta_base = ( - delta_ptr + batch_idx * delta_batch_stride + dim_idx * delta_d_stride - ) - out_base = ( - out_ptr + batch_idx * out_batch_stride + dim_idx * out_d_stride - ) - B_base = B_ptr + batch_idx * B_batch_stride + group_idx * B_group_stride - C_base = C_ptr + batch_idx * C_batch_stride + group_idx * C_group_stride - - if HAS_Z: - if IS_VARLEN: - z_base = z_ptr + dim_idx * z_d_stride + seq_start * z_batch_stride - out_z_base = ( - out_z_ptr - + dim_idx * out_z_d_stride - + seq_start * out_z_batch_stride - ) - else: - z_base = z_ptr + batch_idx * z_batch_stride + dim_idx * z_d_stride - out_z_base = ( - out_z_ptr - + batch_idx * out_z_batch_stride - + dim_idx * out_z_d_stride - ) - - # Determine chunk boundaries for APC mode - if CACHE_ENABLED: - last_chunk_idx = tl.load(last_chunk_indices_ptr + batch_idx).to(tl.int32) - if batch_idx == 0: - first_chunk_idx = 0 - else: - first_chunk_idx = ( - tl.load(last_chunk_indices_ptr + batch_idx - 1).to(tl.int32) + 1 - ) - n_chunks = last_chunk_idx - first_chunk_idx + 1 - first_chunk_tokens = ( - tl.load(cu_chunk_seqlen_ptr + first_chunk_idx + 1).to(tl.int32) - - tl.load(cu_chunk_seqlen_ptr + first_chunk_idx).to(tl.int32) - ) - block_idx_first = tl.load(block_idx_first_ptr + batch_idx).to(tl.int32) - chunk_start_offset = 0 - if n_chunks > 1 and first_chunk_tokens < block_size: - chunk_start_offset = block_size - first_chunk_tokens - current_position = block_idx_first * block_size + chunk_start_offset - else: - n_chunks = 1 - first_chunk_idx = 0 - - # Sequential scan over the sequence - tokens_processed = 0 - for chunk in range(0, n_chunks if CACHE_ENABLED else 1): - if CACHE_ENABLED: - chunk_tokens = ( - tl.load( - cu_chunk_seqlen_ptr + first_chunk_idx + chunk + 1 - ).to(tl.int32) - - tl.load( - cu_chunk_seqlen_ptr + first_chunk_idx + chunk - ).to(tl.int32) - ) - else: - chunk_tokens = actual_seqlen - - for local_pos in range(chunk_tokens): - pos = tokens_processed + local_pos - # Load u value - u_val = tl.load(u_base + pos).to(tl.float32) - - # Load delta value - delta_val = tl.load(delta_base + pos).to(tl.float32) - - # Apply delta bias - if HAS_DELTA_BIAS: - delta_val = delta_val + delta_bias_val - - # Apply softplus - if delta_softplus: - delta_val = _softplus(delta_val) - - delta_u = delta_val * u_val - - # Compute dA = exp(delta * A) for all dstate elements - dA = tl.exp(delta_val * A_vals) - - # Load B values for this position - B_vals = tl.load( - B_base + dstate_offs * B_dstate_stride + pos, - mask=dstate_mask, - other=0.0, - ).to(tl.float32) - - # Load C values for this position - C_vals = tl.load( - C_base + dstate_offs * C_dstate_stride + pos, - mask=dstate_mask, - other=0.0, - ).to(tl.float32) - - # Update state: state = dA * state + delta * u * B - state = dA * state + delta_u * B_vals - - # Compute output: out = sum(state * C) + D * u - out_val = tl.sum(state * C_vals, axis=0) - if HAS_D: - out_val = out_val + D_val * u_val - - # Store output - tl.store(out_base + pos, out_val.to(out_ptr.dtype.element_ty)) - - if HAS_Z: - z_val = tl.load(z_base + pos).to(tl.float32) - out_z_val = out_val * z_val / (1.0 + tl.exp(-z_val)) - tl.store( - out_z_base + pos, - out_z_val.to(out_z_ptr.dtype.element_ty), - ) - - tokens_processed += chunk_tokens - - # Store intermediate state for APC mode - if CACHE_ENABLED: - if chunk == n_chunks - 1: - store_slot = tl.load( - cache_indices_ptr - + batch_idx * cache_indices_stride - + tl.load(block_idx_last_ptr + batch_idx).to(tl.int32) - ).to(tl.int64) - else: - block_idx_done = ( - current_position + chunk_tokens - 1 - ) // block_size - store_slot = tl.load( - cache_indices_ptr - + batch_idx * cache_indices_stride - + block_idx_done - ).to(tl.int64) - - tl.store( - ssm_states_ptr - + store_slot * ssm_batch_stride - + dim_idx * ssm_dim_stride - + dstate_offs * ssm_dstate_stride, - state.to(ssm_states_ptr.dtype.element_ty), - mask=dstate_mask, - ) - current_position += chunk_tokens - - # Store final state for non-APC mode - if not CACHE_ENABLED: - tl.store( - ssm_states_ptr - + load_cache_slot * ssm_batch_stride - + dim_idx * ssm_dim_stride - + dstate_offs * ssm_dstate_stride, - state.to(ssm_states_ptr.dtype.element_ty), - mask=dstate_mask, - ) - - -def selective_scan_fwd_triton( - u: torch.Tensor, - delta: torch.Tensor, - A: torch.Tensor, - B: torch.Tensor, - C: torch.Tensor, - D_: torch.Tensor | None, - z_: torch.Tensor | None, - delta_bias_: torch.Tensor | None, - delta_softplus: bool, - query_start_loc: torch.Tensor | None, - cache_indices: torch.Tensor | None, - has_initial_state: torch.Tensor | None, - ssm_states: torch.Tensor, - null_block_id: int, - block_size: int = 1024, - block_idx_first_scheduled_token: torch.Tensor | None = None, - block_idx_last_scheduled_token: torch.Tensor | None = None, - initial_state_idx: torch.Tensor | None = None, - cu_chunk_seqlen: torch.Tensor | None = None, - last_chunk_indices: torch.Tensor | None = None, -): - """ - Triton implementation of selective scan forward pass. - - This writes output in-place to delta (when z is None) or z (when z is - provided), matching the CUDA kernel's behavior. - - See selective_scan_fn() in mamba_ssm.py for parameter documentation. - """ - varlen = query_start_loc is not None - batch_size = ( - (query_start_loc.shape[0] - 1) if varlen else u.shape[0] - ) - dim = u.shape[0] if varlen else u.shape[1] - total_seqlen = u.shape[1] if varlen else u.shape[2] - dstate = A.size(1) - n_groups = B.size(0) if varlen else B.size(1) - dim_ngroups_ratio = dim // n_groups - - has_z = z_ is not None - has_D = D_ is not None - has_delta_bias = delta_bias_ is not None - has_cache_indices = cache_indices is not None - cache_enabled = block_idx_first_scheduled_token is not None - - # out and out_z alias delta and z respectively - out = delta - out_z = z_ if has_z else delta # dummy, won't be used if not has_z - - BLOCK_DSTATE = triton.next_power_of_2(dstate) - - # Compute strides - if varlen: - u_batch_stride = u.stride(1) - u_d_stride = u.stride(0) - delta_batch_stride = delta.stride(1) - delta_d_stride = delta.stride(0) - B_batch_stride = B.stride(2) - B_group_stride = B.stride(0) - B_dstate_stride = B.stride(1) - C_batch_stride = C.stride(2) - C_group_stride = C.stride(0) - C_dstate_stride = C.stride(1) - out_batch_stride = out.stride(1) - out_d_stride = out.stride(0) - if has_z: - z_batch_stride = z_.stride(1) - z_d_stride = z_.stride(0) - out_z_batch_stride = out_z.stride(1) - out_z_d_stride = out_z.stride(0) - else: - z_batch_stride = 0 - z_d_stride = 0 - out_z_batch_stride = 0 - out_z_d_stride = 0 - else: - u_batch_stride = u.stride(0) - u_d_stride = u.stride(1) - delta_batch_stride = delta.stride(0) - delta_d_stride = delta.stride(1) - B_batch_stride = B.stride(0) - B_group_stride = B.stride(1) - B_dstate_stride = B.stride(2) - C_batch_stride = C.stride(0) - C_group_stride = C.stride(1) - C_dstate_stride = C.stride(2) - out_batch_stride = out.stride(0) - out_d_stride = out.stride(1) - if has_z: - z_batch_stride = z_.stride(0) - z_d_stride = z_.stride(1) - out_z_batch_stride = out_z.stride(0) - out_z_d_stride = out_z.stride(1) - else: - z_batch_stride = 0 - z_d_stride = 0 - out_z_batch_stride = 0 - out_z_d_stride = 0 - - ssm_batch_stride = ssm_states.stride(0) - ssm_dim_stride = ssm_states.stride(1) - ssm_dstate_stride = ssm_states.stride(2) - cache_indices_stride = cache_indices.stride(0) if has_cache_indices else 0 - - grid = (batch_size, dim) - _selective_scan_fwd_kernel[grid]( - u, - delta, - A, - B, - C, - D_ if has_D else u, # dummy, won't be dereferenced - z_ if has_z else u, # dummy - delta_bias_ if has_delta_bias else u, # dummy - out, - out_z, - ssm_states, - query_start_loc if varlen else u, # dummy - cache_indices if has_cache_indices else u, # dummy - has_initial_state, - # APC pointers - block_idx_first_scheduled_token if cache_enabled else u, - block_idx_last_scheduled_token if cache_enabled else u, - initial_state_idx if cache_enabled else u, - cu_chunk_seqlen if cache_enabled else u, - last_chunk_indices if cache_enabled else u, - # Dimensions - batch_size, - dim, - total_seqlen, - dstate, - n_groups, - dim_ngroups_ratio, - # Strides - u_batch_stride, - u_d_stride, - delta_batch_stride, - delta_d_stride, - A.stride(0), - A.stride(1), - B_batch_stride, - B_group_stride, - B_dstate_stride, - C_batch_stride, - C_group_stride, - C_dstate_stride, - z_batch_stride, - z_d_stride, - out_batch_stride, - out_d_stride, - out_z_batch_stride, - out_z_d_stride, - ssm_batch_stride, - ssm_dim_stride, - ssm_dstate_stride, - cache_indices_stride, - null_block_id, - block_size, - # Compile-time constants - delta_softplus=delta_softplus, - HAS_D=has_D, - HAS_Z=has_z, - HAS_DELTA_BIAS=has_delta_bias, - IS_VARLEN=varlen, - HAS_CACHE_INDICES=has_cache_indices, - CACHE_ENABLED=cache_enabled, - BLOCK_DSTATE=BLOCK_DSTATE, - ) From 5149c8dfcc708fb5875653df935045c82c5a5739 Mon Sep 17 00:00:00 2001 From: Marceli Fylcek Date: Fri, 29 May 2026 14:51:33 +0300 Subject: [PATCH 4/5] Precommit Signed-off-by: Marceli Fylcek --- vllm/_xpu_ops.py | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/vllm/_xpu_ops.py b/vllm/_xpu_ops.py index 61d024460f88..237825336f90 100644 --- a/vllm/_xpu_ops.py +++ b/vllm/_xpu_ops.py @@ -854,7 +854,9 @@ def selective_scan_fwd( ) -> None: varlen = query_start_loc is not None batch_size = ( - (query_start_loc.shape[0] - 1) if varlen else u.shape[0] + (query_start_loc.shape[0] - 1) + if query_start_loc is not None + else u.shape[0] ) dim = u.shape[0] if varlen else u.shape[1] total_seqlen = u.shape[1] if varlen else u.shape[2] @@ -870,7 +872,7 @@ def selective_scan_fwd( # out and out_z alias delta and z respectively out = delta - out_z = z_ if has_z else delta # dummy, won't be used if not has_z + out_z = z_ if z_ is not None else delta # won't be used if not has_z BLOCK_DSTATE = triton.next_power_of_2(dstate) @@ -888,7 +890,7 @@ def selective_scan_fwd( C_dstate_stride = C.stride(1) out_batch_stride = out.stride(1) out_d_stride = out.stride(0) - if has_z: + if z_ is not None: z_batch_stride = z_.stride(1) z_d_stride = z_.stride(0) out_z_batch_stride = out_z.stride(1) @@ -911,7 +913,7 @@ def selective_scan_fwd( C_dstate_stride = C.stride(2) out_batch_stride = out.stride(0) out_d_stride = out.stride(1) - if has_z: + if z_ is not None: z_batch_stride = z_.stride(0) z_d_stride = z_.stride(1) out_z_batch_stride = out_z.stride(0) @@ -926,7 +928,7 @@ def selective_scan_fwd( ssm_dim_stride = ssm_states.stride(1) ssm_dstate_stride = ssm_states.stride(2) cache_indices_stride = ( - cache_indices.stride(0) if has_cache_indices else 0 + cache_indices.stride(0) if cache_indices is not None else 0 ) grid = (batch_size, dim) From 545725d06313f4a0fac93c1d342e08088f6facb3 Mon Sep 17 00:00:00 2001 From: Marceli Fylcek Date: Fri, 29 May 2026 15:05:24 +0300 Subject: [PATCH 5/5] Formatting Signed-off-by: Marceli Fylcek --- vllm/_xpu_ops.py | 38 ++++++++++++-------------------------- 1 file changed, 12 insertions(+), 26 deletions(-) diff --git a/vllm/_xpu_ops.py b/vllm/_xpu_ops.py index 237825336f90..a48777e00511 100644 --- a/vllm/_xpu_ops.py +++ b/vllm/_xpu_ops.py @@ -521,9 +521,7 @@ def _selective_scan_fwd_kernel( delta_base = ( delta_ptr + dim_idx * delta_d_stride + seq_start * delta_batch_stride ) - out_base = ( - out_ptr + dim_idx * out_d_stride + seq_start * out_batch_stride - ) + out_base = out_ptr + dim_idx * out_d_stride + seq_start * out_batch_stride B_base = B_ptr + group_idx * B_group_stride + seq_start * B_batch_stride C_base = C_ptr + group_idx * C_group_stride + seq_start * C_batch_stride else: @@ -531,9 +529,7 @@ def _selective_scan_fwd_kernel( delta_base = ( delta_ptr + batch_idx * delta_batch_stride + dim_idx * delta_d_stride ) - out_base = ( - out_ptr + batch_idx * out_batch_stride + dim_idx * out_d_stride - ) + out_base = out_ptr + batch_idx * out_batch_stride + dim_idx * out_d_stride B_base = B_ptr + batch_idx * B_batch_stride + group_idx * B_group_stride C_base = C_ptr + batch_idx * C_batch_stride + group_idx * C_group_stride @@ -541,16 +537,12 @@ def _selective_scan_fwd_kernel( if IS_VARLEN: z_base = z_ptr + dim_idx * z_d_stride + seq_start * z_batch_stride out_z_base = ( - out_z_ptr - + dim_idx * out_z_d_stride - + seq_start * out_z_batch_stride + out_z_ptr + dim_idx * out_z_d_stride + seq_start * out_z_batch_stride ) else: z_base = z_ptr + batch_idx * z_batch_stride + dim_idx * z_d_stride out_z_base = ( - out_z_ptr - + batch_idx * out_z_batch_stride - + dim_idx * out_z_d_stride + out_z_ptr + batch_idx * out_z_batch_stride + dim_idx * out_z_d_stride ) # Determine chunk boundaries for APC mode @@ -563,10 +555,9 @@ def _selective_scan_fwd_kernel( tl.load(last_chunk_indices_ptr + batch_idx - 1).to(tl.int32) + 1 ) n_chunks = last_chunk_idx - first_chunk_idx + 1 - first_chunk_tokens = ( - tl.load(cu_chunk_seqlen_ptr + first_chunk_idx + 1).to(tl.int32) - - tl.load(cu_chunk_seqlen_ptr + first_chunk_idx).to(tl.int32) - ) + first_chunk_tokens = tl.load(cu_chunk_seqlen_ptr + first_chunk_idx + 1).to( + tl.int32 + ) - tl.load(cu_chunk_seqlen_ptr + first_chunk_idx).to(tl.int32) block_idx_first = tl.load(block_idx_first_ptr + batch_idx).to(tl.int32) chunk_start_offset = 0 if n_chunks > 1 and first_chunk_tokens < block_size: @@ -580,13 +571,10 @@ def _selective_scan_fwd_kernel( tokens_processed = 0 for chunk in range(0, n_chunks if CACHE_ENABLED else 1): if CACHE_ENABLED: - chunk_tokens = ( - tl.load( - cu_chunk_seqlen_ptr + first_chunk_idx + chunk + 1 - ).to(tl.int32) - - tl.load( - cu_chunk_seqlen_ptr + first_chunk_idx + chunk - ).to(tl.int32) + chunk_tokens = tl.load( + cu_chunk_seqlen_ptr + first_chunk_idx + chunk + 1 + ).to(tl.int32) - tl.load(cu_chunk_seqlen_ptr + first_chunk_idx + chunk).to( + tl.int32 ) else: chunk_tokens = actual_seqlen @@ -656,9 +644,7 @@ def _selective_scan_fwd_kernel( + tl.load(block_idx_last_ptr + batch_idx).to(tl.int32) ).to(tl.int64) else: - block_idx_done = ( - current_position + chunk_tokens - 1 - ) // block_size + block_idx_done = (current_position + chunk_tokens - 1) // block_size store_slot = tl.load( cache_indices_ptr + batch_idx * cache_indices_stride