diff --git a/vllm/attention/ops/triton_unified_attention.py b/vllm/attention/ops/triton_unified_attention.py index 565be1c39b..80154ea448 100644 --- a/vllm/attention/ops/triton_unified_attention.py +++ b/vllm/attention/ops/triton_unified_attention.py @@ -64,6 +64,7 @@ def kernel_unified_attention_2d( seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] qq_bias_ptr, # [num_query_tokens, num_query_tokens] + cos_sin_cache_ptr, # [max_model_len, head_size] scale, # float32 k_scale, # float32 v_scale, # float32 @@ -86,6 +87,7 @@ def kernel_unified_attention_2d( USE_SOFTCAP: tl.constexpr, # bool USE_SINKS: tl.constexpr, # bool SLIDING_WINDOW: tl.constexpr, # int + FUSE_ROPE: tl.constexpr, # bool stride_k_cache_0: tl.int64, # int stride_k_cache_1: tl.int64, # int stride_k_cache_2: tl.int64, # int @@ -94,6 +96,8 @@ def kernel_unified_attention_2d( stride_v_cache_1: tl.int64, # int stride_v_cache_2: tl.int64, # int stride_v_cache_3: tl.constexpr, # int + stride_cs_cache_0: tl.int64, # int + stride_cs_cache_1: tl.constexpr, # int query_start_len_ptr, # [num_seqs+1] BLOCK_Q: tl.constexpr, # int num_seqs: tl.int32, @@ -128,20 +132,43 @@ def kernel_unified_attention_2d( query_offset_0 = cur_batch_in_all_start_index + query_pos query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv - query_offset = ( + + offs_d_new = tl.arange(0, HEAD_SIZE_PADDED // 2) + + query_offset_a = ( query_offset_0[:, None] * query_stride_0 + query_offset_1[:, None] * query_stride_1 - + offs_d[None, :] + + offs_d_new[None, :] + ) + + query_offset_b = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d_new[None, :] + + HEAD_SIZE_PADDED // 2 ) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + + dim_mask_a = tl.where(offs_d_new < HEAD_SIZE, 1, 0).to(tl.int1) + dim_mask_b = tl.where((HEAD_SIZE_PADDED // 2 + offs_d_new) < HEAD_SIZE, 1, 0).to( + tl.int1 + ) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) - # Q : (BLOCK_M, HEAD_SIZE_PADDED) - Q = tl.load( - query_ptr + query_offset, - mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + # Q_a : (BLOCK_M, HEAD_SIZE_PADDED // 2) + Q_a = tl.load( + query_ptr + query_offset_a, + mask=dim_mask_a[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + # Q_b : (BLOCK_M, HEAD_SIZE_PADDED // 2) + Q_b = tl.load( + query_ptr + query_offset_b, + mask=dim_mask_b[None, :] & query_mask_0[:, None] & query_mask_1[:, None], other=0.0, ) @@ -225,7 +252,6 @@ def kernel_unified_attention_2d( physical_block_idx = tl.load( block_tables_ptr + block_table_offset + seq_offset // BLOCK_SIZE ).to(tl.int64) - v_offset = ( physical_block_idx[:, None] * stride_v_cache_0 + kv_head_idx * stride_v_cache_2 @@ -233,27 +259,78 @@ def kernel_unified_attention_2d( + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 ) - k_offset = ( + k_offset_a = ( physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 - + offs_d[:, None] * stride_k_cache_3 + + offs_d_new[:, None] * stride_k_cache_3 + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 ) - # K : (HEAD_SIZE, TILE_SIZE) - K_load = tl.load( - key_cache_ptr + k_offset, - mask=dim_mask[:, None] & tile_mask[None, :], + k_offset_b = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + (HEAD_SIZE_PADDED // 2 + offs_d_new[:, None]) * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) + + # K_a : (HEAD_SIZE_PADDED // 2, TILE_SIZE) + K_a_load = tl.load( + key_cache_ptr + k_offset_a, + mask=dim_mask_a[:, None] & tile_mask[None, :], other=0.0, ) - if K_load.dtype.is_fp8(): - if Q.dtype.is_fp8(): - K = K_load + if K_a_load.dtype.is_fp8(): + if Q_a.dtype.is_fp8(): + K_a = K_a_load else: - K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + K_a = (K_a_load.to(tl.float32) * tl.load(k_scale)).to(Q_a.dtype) + else: + K_a = K_a_load + + # K_b : (HEAD_SIZE_PADDED // 2, TILE_SIZE) + K_b_load = tl.load( + key_cache_ptr + k_offset_b, + mask=dim_mask_b[:, None] & tile_mask[None, :], + other=0.0, + ) + + if K_b_load.dtype.is_fp8(): + if Q_b.dtype.is_fp8(): + K_b = K_b_load + else: + K_b = (K_b_load.to(tl.float32) * tl.load(k_scale)).to(Q_b.dtype) + else: + K_b = K_b_load + + if FUSE_ROPE: + cos_cache_offset = ( + seq_offset[None, :] * stride_cs_cache_0 + + offs_d_new[:, None] * stride_cs_cache_1 + ) + + sin_cache_offset = ( + seq_offset[None, :] * stride_cs_cache_0 + + (HEAD_SIZE_PADDED // 2 + offs_d_new[:, None]) * stride_cs_cache_1 + ) + + cos = tl.load( + cos_sin_cache_ptr + cos_cache_offset, + mask=dim_mask_a[:, None] & tile_mask[None, :], + other=0.0, + ).to(K_a.dtype) + + sin = tl.load( + cos_sin_cache_ptr + sin_cache_offset, + mask=dim_mask_b[:, None] & tile_mask[None, :], + other=0.0, + ).to(K_b.dtype) + + K_rot_a = K_a * cos - K_b * sin + K_rot_b = K_b * cos + K_a * sin else: - K = K_load + K_rot_a = K_a + K_rot_b = K_b # V : (TILE_SIZE, HEAD_SIZE) V_load = tl.load( @@ -263,10 +340,10 @@ def kernel_unified_attention_2d( ) if V_load.dtype.is_fp8(): - if Q.dtype.is_fp8(): + if Q_a.dtype.is_fp8(): V = V_load else: - V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q_a.dtype) else: V = V_load @@ -274,8 +351,8 @@ def kernel_unified_attention_2d( # S : (BLOCK_M, TILE_SIZE) S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) - - S += scale * tl.dot(Q, K) + S += scale * tl.dot(Q_a, K_rot_a) + S += scale * tl.dot(Q_b, K_rot_b) if USE_SOFTCAP: S = apply_softcap(S, softcap) @@ -366,6 +443,7 @@ def kernel_unified_attention_3d( seq_lens_ptr, # [num_seqs] alibi_slopes_ptr, # [num_query_heads] qq_bias_ptr, # [num_query_tokens, num_query_tokens] + cos_sin_cache_ptr, # [max_model_len, head_size] scale, # float32 k_scale, # float32 v_scale, # float32 @@ -385,6 +463,7 @@ def kernel_unified_attention_3d( USE_SOFTCAP: tl.constexpr, # bool USE_SINKS: tl.constexpr, # bool SLIDING_WINDOW: tl.constexpr, # int + FUSE_ROPE: tl.constexpr, # bool stride_k_cache_0: tl.int64, # int stride_k_cache_1: tl.int64, # int stride_k_cache_2: tl.int64, # int @@ -393,6 +472,8 @@ def kernel_unified_attention_3d( stride_v_cache_1: tl.int64, # int stride_v_cache_2: tl.int64, # int stride_v_cache_3: tl.constexpr, # int + stride_cs_cache_0: tl.int64, # int + stride_cs_cache_1: tl.constexpr, # int query_start_len_ptr, # [num_seqs+1] BLOCK_Q: tl.constexpr, # int num_seqs: tl.int32, @@ -436,20 +517,43 @@ def kernel_unified_attention_3d( query_offset_0 = cur_batch_in_all_start_index + query_pos query_offset_1 = kv_head_idx * num_queries_per_kv + offs_m % num_queries_per_kv - query_offset = ( + + offs_d_new = tl.arange(0, HEAD_SIZE_PADDED // 2) + + query_offset_a = ( query_offset_0[:, None] * query_stride_0 + query_offset_1[:, None] * query_stride_1 - + offs_d[None, :] + + offs_d_new[None, :] + ) + + query_offset_b = ( + query_offset_0[:, None] * query_stride_0 + + query_offset_1[:, None] * query_stride_1 + + offs_d_new[None, :] + + HEAD_SIZE_PADDED // 2 ) dim_mask = tl.where(offs_d < HEAD_SIZE, 1, 0).to(tl.int1) + + dim_mask_a = tl.where(offs_d_new < HEAD_SIZE, 1, 0).to(tl.int1) + dim_mask_b = tl.where((HEAD_SIZE_PADDED // 2 + offs_d_new) < HEAD_SIZE, 1, 0).to( + tl.int1 + ) + query_mask_0 = tl.where(query_pos < cur_batch_query_len, 1, 0).to(tl.int1) query_mask_1 = tl.where(query_offset_1 < num_query_heads, 1, 0).to(tl.int1) - # Q : (BLOCK_M, HEAD_SIZE_PADDED) - Q = tl.load( - query_ptr + query_offset, - mask=dim_mask[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + # Q_a : (BLOCK_M, HEAD_SIZE_PADDED // 2) + Q_a = tl.load( + query_ptr + query_offset_a, + mask=dim_mask_a[None, :] & query_mask_0[:, None] & query_mask_1[:, None], + other=0.0, + ) + + # Q_b : (BLOCK_M, HEAD_SIZE_PADDED // 2) + Q_b = tl.load( + query_ptr + query_offset_b, + mask=dim_mask_b[None, :] & query_mask_0[:, None] & query_mask_1[:, None], other=0.0, ) @@ -522,29 +626,80 @@ def kernel_unified_attention_3d( + (seq_offset % BLOCK_SIZE)[:, None] * stride_v_cache_1 ) - k_offset = ( + k_offset_a = ( + physical_block_idx[None, :] * stride_k_cache_0 + + kv_head_idx * stride_k_cache_2 + + offs_d_new[:, None] * stride_k_cache_3 + + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 + ) + + k_offset_b = ( physical_block_idx[None, :] * stride_k_cache_0 + kv_head_idx * stride_k_cache_2 - + offs_d[:, None] * stride_k_cache_3 + + (HEAD_SIZE_PADDED // 2 + offs_d_new[:, None]) * stride_k_cache_3 + (seq_offset % BLOCK_SIZE)[None, :] * stride_k_cache_1 ) - # K : (HEAD_SIZE, TILE_SIZE) - K_load = tl.load( - key_cache_ptr + k_offset, - mask=dim_mask[:, None] & tile_mask[None, :], + # K_a : (HEAD_SIZE_PADDED // 2, TILE_SIZE) + K_a_load = tl.load( + key_cache_ptr + k_offset_a, + mask=dim_mask_a[:, None] & tile_mask[None, :], other=0.0, ) - if K_load.dtype.is_fp8(): - if Q.dtype.is_fp8(): - K = K_load + if K_a_load.dtype.is_fp8(): + if Q_a.dtype.is_fp8(): + K_a = K_a_load else: - K = (K_load.to(tl.float32) * tl.load(k_scale)).to(Q.dtype) + K_a = (K_a_load.to(tl.float32) * tl.load(k_scale)).to(Q_a.dtype) else: - K = K_load + K_a = K_a_load - # V : (TILE_SIZE, HEAD_SIZE) + # K_b : (HEAD_SIZE_PADDED // 2, TILE_SIZE) + K_b_load = tl.load( + key_cache_ptr + k_offset_b, + mask=dim_mask_b[:, None] & tile_mask[None, :], + other=0.0, + ) + + if K_b_load.dtype.is_fp8(): + if Q_b.dtype.is_fp8(): + K_b = K_b_load + else: + K_b = (K_b_load.to(tl.float32) * tl.load(k_scale)).to(Q_b.dtype) + else: + K_b = K_b_load + + if FUSE_ROPE: + cos_cache_offset = ( + seq_offset[None, :] * stride_cs_cache_0 + + offs_d_new[:, None] * stride_cs_cache_1 + ) + + sin_cache_offset = ( + seq_offset[None, :] * stride_cs_cache_0 + + (HEAD_SIZE_PADDED // 2 + offs_d_new[:, None]) * stride_cs_cache_1 + ) + + cos = tl.load( + cos_sin_cache_ptr + cos_cache_offset, + mask=dim_mask_a[:, None] & tile_mask[None, :], + other=0.0, + ).to(K_a.dtype) + + sin = tl.load( + cos_sin_cache_ptr + sin_cache_offset, + mask=dim_mask_b[:, None] & tile_mask[None, :], + other=0.0, + ).to(K_b.dtype) + + K_rot_a = K_a * cos - K_b * sin + K_rot_b = K_b * cos + K_a * sin + else: + K_rot_a = K_a + K_rot_b = K_b + + # V : (TILE_SIZE, HEAD_SIZE_PADDED) V_load = tl.load( value_cache_ptr + v_offset, mask=dim_mask[None, :] & tile_mask[:, None], @@ -552,10 +707,10 @@ def kernel_unified_attention_3d( ) if V_load.dtype.is_fp8(): - if Q.dtype.is_fp8(): + if Q_a.dtype.is_fp8(): V = V_load else: - V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q.dtype) + V = (V_load.to(tl.float32) * tl.load(v_scale)).to(Q_a.dtype) else: V = V_load @@ -563,7 +718,8 @@ def kernel_unified_attention_3d( # S : (BLOCK_M, TILE_SIZE) S = tl.zeros(shape=(BLOCK_M, TILE_SIZE), dtype=tl.float32) - S += scale * tl.dot(Q, K) + S += scale * tl.dot(Q_a, K_rot_a) + S += scale * tl.dot(Q_b, K_rot_b) if USE_SOFTCAP: S = apply_softcap(S, softcap) @@ -754,6 +910,7 @@ def unified_attention( qq_bias=None, # Optional tensor for sinks sinks=None, + cos_sin_cache=None, ): assert causal, "Only causal attention is supported" assert q_descale is None, "Q scales not supported" @@ -763,6 +920,7 @@ def unified_attention( use_alibi_slopes = alibi_slopes is not None use_qq_bias = qq_bias is not None + fuse_rope = cos_sin_cache is not None block_size = v.shape[1] num_seqs = len(seqused_k) @@ -810,6 +968,7 @@ def unified_attention( seq_lens_ptr=seqused_k, alibi_slopes_ptr=alibi_slopes, qq_bias_ptr=qq_bias, + cos_sin_cache_ptr=cos_sin_cache, scale=softmax_scale, k_scale=k_descale, v_scale=v_descale, @@ -832,6 +991,7 @@ def unified_attention( USE_SOFTCAP=(softcap > 0), USE_SINKS=(sinks is not None), SLIDING_WINDOW=(1 + window_size[0]), + FUSE_ROPE=fuse_rope, stride_k_cache_0=k.stride(0), stride_k_cache_1=k.stride(1), stride_k_cache_2=k.stride(2), @@ -840,6 +1000,8 @@ def unified_attention( stride_v_cache_1=v.stride(1), stride_v_cache_2=v.stride(2), stride_v_cache_3=v.stride(3), + stride_cs_cache_0=cos_sin_cache.stride(0) if fuse_rope else 0, + stride_cs_cache_1=cos_sin_cache.stride(1) if fuse_rope else 0, query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, @@ -886,6 +1048,7 @@ def unified_attention( seq_lens_ptr=seqused_k, alibi_slopes_ptr=alibi_slopes, qq_bias_ptr=qq_bias, + cos_sin_cache_ptr=cos_sin_cache, scale=softmax_scale, k_scale=k_descale, v_scale=v_descale, @@ -905,6 +1068,7 @@ def unified_attention( USE_SOFTCAP=(softcap > 0), USE_SINKS=(sinks is not None), SLIDING_WINDOW=(1 + window_size[0]), + FUSE_ROPE=fuse_rope, stride_k_cache_0=k.stride(0), stride_k_cache_1=k.stride(1), stride_k_cache_2=k.stride(2), @@ -913,6 +1077,8 @@ def unified_attention( stride_v_cache_1=v.stride(1), stride_v_cache_2=v.stride(2), stride_v_cache_3=v.stride(3), + stride_cs_cache_0=cos_sin_cache.stride(0) if fuse_rope else 0, + stride_cs_cache_1=cos_sin_cache.stride(1) if fuse_rope else 0, query_start_len_ptr=cu_seqlens_q, BLOCK_Q=BLOCK_Q, num_seqs=num_seqs, diff --git a/vllm/envs.py b/vllm/envs.py index 67b508cb2d..97c5b77b9a 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -228,7 +228,6 @@ VLLM_V1_SPANS_DEBUG: bool = False VLLM_V1_SPANS_TOKEN_PLUS: int = -1 VLLM_V1_SPANS_TOKEN_CROSS: int = -1 - VLLM_V1_SPANS_DISABLE_REPOSITION: bool = False def get_default_cache_root(): @@ -1504,11 +1503,6 @@ def get_vllm_port() -> int | None: "VLLM_V1_SPANS_TOKEN_CROSS": lambda: int( os.environ.get("VLLM_V1_SPANS_TOKEN_CROSS", "-1") ), - # for block-attention, detected spans will be loaded but not repositioned - "VLLM_V1_SPANS_DISABLE_REPOSITION": lambda: os.environ.get( - "VLLM_V1_SPANS_DISABLE_REPOSITION", "False" - ) - == "True", } # --8<-- [end:env-vars-definition] diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py index 2fc00130da..d85ded63fa 100644 --- a/vllm/model_executor/layers/rotary_embedding/base.py +++ b/vllm/model_executor/layers/rotary_embedding/base.py @@ -4,6 +4,7 @@ import torch +import vllm.envs as envs from vllm.model_executor.custom_op import CustomOp from .common import apply_rotary_emb_torch @@ -106,15 +107,12 @@ def forward_native( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor | None = None, - invert_rotation_angle: bool = False, # <- to unrope kv's ) -> tuple[torch.Tensor, torch.Tensor | None]: """A PyTorch-native implementation of forward().""" positions = positions.flatten() num_tokens = positions.shape[0] cos_sin = self.cos_sin_cache.index_select(0, positions) cos, sin = cos_sin.chunk(2, dim=-1) - if invert_rotation_angle: - sin = -sin query_shape = query.shape query = query.view(num_tokens, -1, self.head_size) @@ -124,7 +122,7 @@ def forward_native( query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape) # key may be None in some cases, e.g. cross-layer KV sharing - if key is not None: + if key is not None and not envs.VLLM_V1_SPANS_ENABLED: key_shape = key.shape key = key.view(num_tokens, -1, self.head_size) key_rot = key[..., : self.rotary_dim] @@ -159,7 +157,7 @@ def forward_cuda( ops.rotary_embedding( positions, query, - key, + None if envs.VLLM_V1_SPANS_ENABLED else key, self.head_size, self.cos_sin_cache, self.is_neox_style, diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py index 34280c2d37..0592aa8f96 100644 --- a/vllm/model_executor/layers/rotary_embedding/mrope.py +++ b/vllm/model_executor/layers/rotary_embedding/mrope.py @@ -265,7 +265,6 @@ def forward_native( query: torch.Tensor, key: torch.Tensor | None = None, offsets: torch.Tensor | None = None, - invert_rotation_angle: bool = False, ) -> tuple[torch.Tensor, torch.Tensor | None]: """PyTorch-native implementation equivalent to forward(). diff --git a/vllm/v1/attention/backends/triton_attn.py b/vllm/v1/attention/backends/triton_attn.py index 0590a87bf8..609e763077 100644 --- a/vllm/v1/attention/backends/triton_attn.py +++ b/vllm/v1/attention/backends/triton_attn.py @@ -7,6 +7,7 @@ import torch +import vllm.envs as envs from vllm.attention.backends.abstract import ( AttentionBackend, AttentionImpl, @@ -59,6 +60,8 @@ class TritonAttentionMetadata: prefix_kv_lens: torch.Tensor | None suffix_kv_lens: torch.Tensor | None + cos_sin_cache: torch.Tensor | None = None + # Optional aot scheduling scheduler_metadata: torch.Tensor | None = None prefix_scheduler_metadata: torch.Tensor | None = None @@ -127,6 +130,11 @@ def build( suffix_kv_lens = None prefix_scheduler_metadata = None + if envs.VLLM_V1_SPANS_ENABLED: + cos_sin_cache = common_attn_metadata.cos_sin_cache + else: + cos_sin_cache = None + attn_metadata = TritonAttentionMetadata( num_actual_tokens=num_actual_tokens, max_query_len=max_query_len, @@ -141,6 +149,7 @@ def build( prefix_kv_lens=prefix_kv_lens, suffix_kv_lens=suffix_kv_lens, prefix_scheduler_metadata=prefix_scheduler_metadata, + cos_sin_cache=cos_sin_cache, ) return attn_metadata @@ -343,6 +352,8 @@ def forward( descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1]) + cos_sin_cache = attn_metadata.cos_sin_cache + unified_attention( q=query[:num_actual_tokens], k=key_cache, @@ -363,6 +374,7 @@ def forward( v_descale=layer._v_scale.expand(descale_shape), sinks=self.sinks, output_scale=output_scale, + cos_sin_cache=cos_sin_cache, ) return output diff --git a/vllm/v1/attention/backends/utils.py b/vllm/v1/attention/backends/utils.py index 07dfbc766a..04985c4bfd 100644 --- a/vllm/v1/attention/backends/utils.py +++ b/vllm/v1/attention/backends/utils.py @@ -82,6 +82,8 @@ class CommonAttentionMetadata: block_table_tensor: torch.Tensor slot_mapping: torch.Tensor + cos_sin_cache: torch.Tensor | None = None + causal: bool = True # Needed by FastPrefillAttentionBuilder diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py index 62be679309..9bfd8b69ee 100644 --- a/vllm/v1/core/block_pool.py +++ b/vllm/v1/core/block_pool.py @@ -3,7 +3,6 @@ from collections.abc import Iterable, Sequence from typing import Any -import vllm.envs as envs from vllm.distributed.kv_events import ( MEDIUM_GPU, AllBlocksCleared, @@ -242,8 +241,6 @@ def cache_full_blocks( if new_hashes is not None: new_hashes.append(maybe_convert_block_hash(block_hash)) - self._set_block_positions(new_full_blocks, blocks, request) - if self.enable_kv_cache_events: if num_cached_blocks == 0: parent_block_hash: ExternalBlockHash | None = None @@ -269,58 +266,6 @@ def cache_full_blocks( ) ) - def _set_block_positions( - self, - new_full_blocks: list[KVCacheBlock], - blocks: list[KVCacheBlock], - request: Request, - ): - """Sets the positions of new full blocks in the KV cache. - - This function assigns positions to newly filled blocks based - on their order within the provided block list. The position - corresponds to the location embedded in K vectors (if using RoPE) - in the KV cache and is critical for maintaining correct alignment, - especially when prompt positions differ between requests. - - Args: - new_full_blocks: List of KVCacheBlock objects that have been newly - filled and require position assignment. - blocks: List of all blocks associated with the current request, - used to determine the order in which positions are assigned. - request: The Request object containing token information for - debugging purposes. - - Note: - When VLLM_V1_SPANS_DEBUG is enabled, this function includes - debug logging that prints each block's tokens, to help - debug span-related workflows. - """ - pos = 0 - for blk in blocks: - if blk in new_full_blocks: - blk.position = pos - if envs.VLLM_V1_SPANS_DEBUG: - # this prints the tokens assigned to a new block - # in the KV cache - blk_tks = request.all_token_ids[pos : pos + 16] - assert blk.block_hash is not None - bhash = str(blk.block_hash)[:4] if blk.block_hash else None - print( - "[SPANS -> block_pool] assigning to pos", - pos, - "with hash", - bhash, - "block: ", - blk_tks, - ) - pos += 16 - if envs.VLLM_V1_SPANS_DEBUG: - print( - "[SPANS -> block_pool] assigned block count now ->", - len([b for b in self.blocks if b._block_hash]), - ) - def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]: """Get new blocks from the free block pool. @@ -413,19 +358,11 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None: blocks_list = list(ordered_blocks) for block in blocks_list: block.ref_cnt -= 1 - # remove duplicates (blocks can now appear twice) - block_ids = set() - blocks_list_filtered = [] - for block in blocks_list: - if block.block_id not in block_ids: - blocks_list_filtered.append(block) - block_ids.add(block.block_id) + # Remove duplicates while preserving order + dedup_bl = list({block.block_id: block for + block in blocks_list}.values()) self.free_block_queue.append_n( - [ - block - for block in blocks_list_filtered - if block.ref_cnt == 0 and not block.is_null - ] + [block for block in dedup_bl if block.ref_cnt == 0 and not block.is_null] ) def reset_prefix_cache(self) -> bool: diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py index 67eafe5d82..63a1ff06e4 100644 --- a/vllm/v1/core/kv_cache_manager.py +++ b/vllm/v1/core/kv_cache_manager.py @@ -6,7 +6,6 @@ from dataclasses import dataclass from typing import Literal, overload -import vllm.envs as envs from vllm.distributed.kv_events import KVCacheEvent from vllm.logger import init_logger from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator @@ -18,13 +17,6 @@ logger = init_logger(__name__) -@dataclass -class BlockRepositionRequest: - block_id: int - kvc_pos: int - prompt_pos: int - - @dataclass class KVCacheBlocks: """ @@ -34,7 +26,6 @@ class KVCacheBlocks: """ blocks: tuple[Sequence[KVCacheBlock], ...] - blocks_to_reposition: list[BlockRepositionRequest] """ `blocks[i][j]` refers to the i-th kv_cache_group and the j-th block of tokens.We don't use block of @@ -55,8 +46,7 @@ def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks": tuple( list(itertools.chain(blk1, blk2)) for blk1, blk2 in zip(self.blocks, other.blocks) - ), - self.blocks_to_reposition + other.blocks_to_reposition, + ) ) @overload @@ -97,7 +87,7 @@ def new_empty(self) -> "KVCacheBlocks": """ Creates a new KVCacheBlocks instance with no blocks. """ - return KVCacheBlocks(tuple(() for _ in range(len(self.blocks))), []) + return KVCacheBlocks(tuple(() for _ in range(len(self.blocks)))) class KVCacheManager: @@ -159,7 +149,7 @@ def __init__( # # We use nested tuples to ensure the empty KVCacheBlocks is immutable. self.empty_kv_cache_blocks = KVCacheBlocks( - tuple(() for _ in range(self.num_kv_cache_groups)), [] + tuple(() for _ in range(self.num_kv_cache_groups)) ) @property @@ -215,58 +205,6 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: request.block_hashes, max_cache_hit_length ) ) - if envs.VLLM_V1_SPANS_DEBUG: - print( - "[SPANS -> kv_cache_manager] here's the blocks hashed in this request:", - [str(b)[:4] for b in request.block_hashes], - ) - kvcache_contents = [ - str(b.block_hash)[:4] if b.block_hash else None - for b in self.block_pool.blocks - if b._block_hash - ] - if len(kvcache_contents) > 32: - kvcache_contents = kvcache_contents[:32] + [ - "... (too long to print it all)" - ] - print( - "[SPANS -> kv_cache_manager] here's the contents of the kv cache:", - kvcache_contents, - ) - print( - "[SPANS -> kv_cache_manager] here's the number of blocks " - "that hit the cache:", - [ - str(b.block_hash)[:4] if b.block_hash else None - for b in computed_blocks[0] - ], - ) - - blocks_to_reposition = [] - if envs.VLLM_V1_SPANS_ENABLED: - # Spans does yet not support hybrid models - assert len(computed_blocks) == 1 - for i, b in enumerate(computed_blocks[0]): - prompt_pos = i * 16 - kvc_pos = b.position - if envs.VLLM_V1_SPANS_DEBUG: - print( - f"[SPANS -> kv_cache_manager] checking block " - f"{b.block_id} with prompot pos {prompt_pos} " - f"and kv pos {kvc_pos}" - ) - assert isinstance(kvc_pos, int) - if kvc_pos != prompt_pos: - if envs.VLLM_V1_SPANS_DEBUG: - print( - f"[SPANS -> kv_cache_manager] from pos: {kvc_pos} " - f"to prompt pos: {prompt_pos} repositioning needed" - ) - - blocks_to_reposition.append( - BlockRepositionRequest(b.block_id, kvc_pos, prompt_pos) - ) - b.position = int(prompt_pos) if self.log_stats: assert self.prefix_cache_stats is not None @@ -276,9 +214,7 @@ def get_computed_blocks(self, request: Request) -> tuple[KVCacheBlocks, int]: preempted=request.num_preemptions > 0, ) - return self.create_kv_cache_blocks( - computed_blocks, blocks_to_reposition - ), num_new_computed_tokens + return self.create_kv_cache_blocks(computed_blocks), num_new_computed_tokens def allocate_slots( self, @@ -384,7 +320,7 @@ def allocate_slots( # P/D: delay caching blocks if we have to recv from # remote. Update state for locally cached blocks. if not self.enable_caching or delay_cache_blocks: - return self.create_kv_cache_blocks(new_blocks, []) + return self.create_kv_cache_blocks(new_blocks) # NOTE(woosuk): We want to commit (cache) up to num_computed_tokens + # num_new_tokens, but must exclude "non-committable" tokens (e.g., @@ -395,7 +331,7 @@ def allocate_slots( ) self.coordinator.cache_blocks(request, num_tokens_to_cache) - return self.create_kv_cache_blocks(new_blocks, []) + return self.create_kv_cache_blocks(new_blocks) def free(self, request: Request) -> None: """Free the blocks allocated for the request. @@ -467,7 +403,7 @@ def take_events(self) -> list[KVCacheEvent]: def get_blocks(self, request_id: str) -> KVCacheBlocks: """Get the blocks of a request.""" - return self.create_kv_cache_blocks(self.coordinator.get_blocks(request_id), []) + return self.create_kv_cache_blocks(self.coordinator.get_blocks(request_id)) def get_block_ids(self, request_id: str) -> tuple[list[int], ...]: """Get the block ids of a request.""" @@ -479,13 +415,7 @@ def cache_blocks(self, request: Request, num_computed_tokens: int) -> None: self.coordinator.cache_blocks(request, num_computed_tokens) def create_kv_cache_blocks( - self, - blocks: tuple[list[KVCacheBlock], ...], - blocks_to_reposition: list[BlockRepositionRequest], + self, blocks: tuple[list[KVCacheBlock], ...] ) -> KVCacheBlocks: # Only create new KVCacheBlocks for non-empty blocks - return ( - KVCacheBlocks(blocks, blocks_to_reposition) - if any(blocks) - else self.empty_kv_cache_blocks - ) + return KVCacheBlocks(blocks) if any(blocks) else self.empty_kv_cache_blocks diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py index d6e2250f59..e2089cbf5f 100644 --- a/vllm/v1/core/kv_cache_utils.py +++ b/vllm/v1/core/kv_cache_utils.py @@ -111,8 +111,6 @@ class KVCacheBlock: block_id: int # Reference count. ref_cnt: int = 0 - # Position (corresponds to positional encodings position) - position: int | None = None # The hash key (block hash + group id) of the block, only available # when the block is full and cached. _block_hash: BlockHashWithGroupId | None = None diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index dff7ae389a..866136648b 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -19,7 +19,6 @@ from vllm.multimodal.inputs import MultiModalFeatureSpec from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams - from vllm.v1.core.kv_cache_manager import BlockRepositionRequest from vllm.v1.request import Request else: KVConnectorMetadata = object @@ -28,7 +27,6 @@ PoolingParams = object SamplingParams = object Request = object - BlockRepositionRequest = object @bc_linter_include @@ -183,9 +181,6 @@ class SchedulerOutput: # freed from the encoder cache. free_encoder_mm_hashes: list[str] - # for KV cache repositioning (as part of Block-Attention implementation) - blocks_to_reposition: list[BlockRepositionRequest] - # Whether the scheduled requests have all the output tokens they # need to perform grammar bitmask computation. pending_structured_output_tokens: bool = False diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index a03c47e1c6..c17b19b58c 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -6,7 +6,6 @@ from collections.abc import Iterable from typing import Any -import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_transfer.kv_connector.factory import KVConnectorFactory @@ -22,11 +21,7 @@ EncoderCacheManager, compute_encoder_budget, ) -from vllm.v1.core.kv_cache_manager import ( - BlockRepositionRequest, - KVCacheBlocks, - KVCacheManager, -) +from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import ( CachedRequestData, @@ -359,7 +354,6 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests = create_request_queue(self.policy) # Next, schedule the WAITING requests. - blocks_to_reposition: list[BlockRepositionRequest] = [] if not preempted_reqs: while self.waiting and token_budget > 0: if len(self.running) == self.max_num_running_reqs: @@ -417,15 +411,6 @@ def schedule(self) -> SchedulerOutput: self.kv_cache_manager.get_computed_blocks(request) ) - # handle repositioning requests - if ( - envs.VLLM_V1_SPANS_ENABLED - and len(new_computed_blocks.blocks_to_reposition) > 0 - ): - blocks_to_reposition.extend( - new_computed_blocks.blocks_to_reposition - ) - # Get externally-cached tokens if using a KVConnector. if self.connector is not None: ext_tokens, load_kv_async = ( @@ -655,7 +640,6 @@ def schedule(self) -> SchedulerOutput: # the previous and the current steps. finished_req_ids=self.finished_req_ids, free_encoder_mm_hashes=self.encoder_cache_manager.get_freed_mm_hashes(), - blocks_to_reposition=blocks_to_reposition, ) # NOTE(Kuntai): this function is designed for multiple purposes: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index d92b541b14..1690db9f34 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -1413,6 +1413,15 @@ def _build_attention_metadata( # graph mode. blk_table.slot_mapping.gpu[total_num_scheduled_tokens:].fill_(-1) + if not hasattr(self, "rotate"): + if not isinstance(self.model.model.layers[0], PPMissingLayer): + self.rotate = self.model.model.layers[0].self_attn.rotary_emb + else: + for lay in self.model.model.layers: + if not isinstance(lay, PPMissingLayer): + self.rotate = lay.self_attn.rotary_emb + break + common_attn_metadata = CommonAttentionMetadata( query_start_loc=query_start_loc, query_start_loc_cpu=query_start_loc_cpu, @@ -1430,6 +1439,7 @@ def _build_attention_metadata( causal=True, encoder_seq_lens=encoder_seq_lens, dcp_local_seq_lens=dcp_local_seq_lens, + cos_sin_cache=self.rotate.cos_sin_cache, ) if self.speculative_config and spec_decode_common_attn_metadata is None: @@ -2181,116 +2191,6 @@ def _pool( pooler_output=pooler_output, ) - def _perform_repositioning(self, scheduler_output: "SchedulerOutput") -> None: - """ - Repositions KV cache blocks based on the scheduler's instructions. - - This method handles the repositioning of attention block - vectors in the KV cache when their positions in the KV cache - and in the prompt differ. It applies rotary embedding - transformations to adjust the positions. - - Args: - scheduler_output: The output from the scheduler containing blocks - to reposition. - """ - blocks_to_reposition = scheduler_output.blocks_to_reposition - if envs.VLLM_V1_SPANS_DEBUG: - ts_repo = time.time() - repo_count = len(blocks_to_reposition) - if len(blocks_to_reposition) > 0: - bs = 512 - for i in range(0, len(blocks_to_reposition), bs): - repo_batch = blocks_to_reposition[i : i + bs] - self._repositionings_handler(repo_batch) - if envs.VLLM_V1_SPANS_DEBUG and repo_count > 0: - torch.cuda.synchronize() - t_repo = time.time() - ts_repo - print( - f"[SPANS -> gpu_model_runner] repositioning" - f" speed: {repo_count / t_repo:.2f} (blocks/s)" - f" (total {repo_count})" - ) - - @torch.inference_mode() - def _repositionings_handler(self, blocks_to_reposition): - num_repos = len(blocks_to_reposition) - if envs.VLLM_V1_SPANS_DEBUG and num_repos > 0: - print(f"[SPANS -> gpu_model_runner] reposition block count: {num_repos}") - if not envs.VLLM_V1_SPANS_DISABLE_REPOSITION: - kvc_positions = torch.tensor( - [d.kvc_pos for d in blocks_to_reposition], - dtype=torch.long, - device=self.kv_caches[0].device, - ).unsqueeze(-1) - prt_positions = torch.tensor( - [d.prompt_pos for d in blocks_to_reposition], - dtype=torch.long, - device=self.kv_caches[0].device, - ).unsqueeze(-1) - block_ids = torch.tensor( - [d.block_id for d in blocks_to_reposition], - dtype=torch.long, - device=self.kv_caches[0].device, - ) - - # (self.kv_caches shape): - # [nlay, kv, maxblocks, blocksize, headcount, headsize] - concerned_vectors = [ - x[0, block_ids, :, :, :] for x in self.kv_caches - ] # -> [nlay, blockids, blocksize, headcount, headsize] - bids, bsize, hcount, hsize = concerned_vectors[0].shape - - template_tensor = torch.arange( - bsize, dtype=torch.long, device=self.kv_caches[0].device - ).unsqueeze(0) - pos_depos = kvc_positions + template_tensor - pos_repos = prt_positions + template_tensor - - # precision highly affects the outputs - PRECISION = torch.float32 - DEF_PRECISION = self.kv_caches[0].dtype - - # do the rotation - # note: PPMissingLayer is for pipeline parallel support - if not hasattr(self, "rotate"): - if not isinstance(self.model.model.layers[0], PPMissingLayer): - self.rotate = self.model.model.layers[0].self_attn.rotary_emb - else: - for lay in self.model.model.layers: - if not isinstance(lay, PPMissingLayer): - self.rotate = lay.self_attn.rotary_emb - break - assert pos_depos.shape[0] == concerned_vectors[0].shape[0] - - if num_repos > 100: - for i, k_vectors in enumerate(concerned_vectors): - k_vectors_tmp, _ = self.rotate.forward_native( - pos_depos, k_vectors.to(PRECISION), invert_rotation_angle=True - ) - k_vectors_tmp, _ = self.rotate.forward_native( - pos_repos, k_vectors_tmp - ) - self.kv_caches[i][0, block_ids, ...] = k_vectors_tmp.to( - DEF_PRECISION - ) - else: - nlays = len(concerned_vectors) - kvecs = torch.cat(concerned_vectors, dim=0).to(PRECISION) - k_vectors_tmp, _ = self.rotate.forward_native( - pos_depos.repeat(nlays, 1), kvecs, invert_rotation_angle=True - ) - k_vectors_tmp, _ = self.rotate.forward_native( - pos_repos.repeat(nlays, 1), k_vectors_tmp - ) - k_vectors_tmp = k_vectors_tmp.reshape( - nlays, *concerned_vectors[0].shape - ) - for i in range(len(self.kv_caches)): - self.kv_caches[i][0, block_ids, ...] = k_vectors_tmp[i].to( - DEF_PRECISION - ) - def _get_num_input_tokens(self, num_scheduled_tokens: int) -> int: if ( self.compilation_config.cudagraph_mode != CUDAGraphMode.NONE @@ -2637,10 +2537,6 @@ def execute_model( ) num_scheduled_tokens = scheduler_output.total_num_scheduled_tokens with record_function_or_nullcontext("Preprocess"): - # NOTE(tdoublep): should this be inside context below? - # handle repositioning requests - self._perform_repositioning(scheduler_output) - with self.synchronize_input_prep(): # Update persistent batch states. self._update_states(scheduler_output)