Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
252 changes: 209 additions & 43 deletions vllm/attention/ops/triton_unified_attention.py

Large diffs are not rendered by default.

6 changes: 0 additions & 6 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,6 @@
VLLM_V1_SPANS_DEBUG: bool = False
VLLM_V1_SPANS_TOKEN_PLUS: int = -1
VLLM_V1_SPANS_TOKEN_CROSS: int = -1
VLLM_V1_SPANS_DISABLE_REPOSITION: bool = False


def get_default_cache_root():
Expand Down Expand Up @@ -1504,11 +1503,6 @@ def get_vllm_port() -> int | None:
"VLLM_V1_SPANS_TOKEN_CROSS": lambda: int(
os.environ.get("VLLM_V1_SPANS_TOKEN_CROSS", "-1")
),
# for block-attention, detected spans will be loaded but not repositioned
"VLLM_V1_SPANS_DISABLE_REPOSITION": lambda: os.environ.get(
"VLLM_V1_SPANS_DISABLE_REPOSITION", "False"
)
== "True",
}

# --8<-- [end:env-vars-definition]
Expand Down
8 changes: 3 additions & 5 deletions vllm/model_executor/layers/rotary_embedding/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import torch

import vllm.envs as envs
from vllm.model_executor.custom_op import CustomOp

from .common import apply_rotary_emb_torch
Expand Down Expand Up @@ -106,15 +107,12 @@ def forward_native(
positions: torch.Tensor,
query: torch.Tensor,
key: torch.Tensor | None = None,
invert_rotation_angle: bool = False, # <- to unrope kv's
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""A PyTorch-native implementation of forward()."""
positions = positions.flatten()
num_tokens = positions.shape[0]
cos_sin = self.cos_sin_cache.index_select(0, positions)
cos, sin = cos_sin.chunk(2, dim=-1)
if invert_rotation_angle:
sin = -sin

query_shape = query.shape
query = query.view(num_tokens, -1, self.head_size)
Expand All @@ -124,7 +122,7 @@ def forward_native(
query = torch.cat((query_rot, query_pass), dim=-1).reshape(query_shape)

# key may be None in some cases, e.g. cross-layer KV sharing
if key is not None:
if key is not None and not envs.VLLM_V1_SPANS_ENABLED:
key_shape = key.shape
key = key.view(num_tokens, -1, self.head_size)
key_rot = key[..., : self.rotary_dim]
Expand Down Expand Up @@ -159,7 +157,7 @@ def forward_cuda(
ops.rotary_embedding(
positions,
query,
key,
None if envs.VLLM_V1_SPANS_ENABLED else key,
self.head_size,
self.cos_sin_cache,
self.is_neox_style,
Expand Down
1 change: 0 additions & 1 deletion vllm/model_executor/layers/rotary_embedding/mrope.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,6 @@ def forward_native(
query: torch.Tensor,
key: torch.Tensor | None = None,
offsets: torch.Tensor | None = None,
invert_rotation_angle: bool = False,
) -> tuple[torch.Tensor, torch.Tensor | None]:
"""PyTorch-native implementation equivalent to forward().

Expand Down
12 changes: 12 additions & 0 deletions vllm/v1/attention/backends/triton_attn.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import torch

import vllm.envs as envs
from vllm.attention.backends.abstract import (
AttentionBackend,
AttentionImpl,
Expand Down Expand Up @@ -59,6 +60,8 @@ class TritonAttentionMetadata:
prefix_kv_lens: torch.Tensor | None
suffix_kv_lens: torch.Tensor | None

cos_sin_cache: torch.Tensor | None = None

# Optional aot scheduling
scheduler_metadata: torch.Tensor | None = None
prefix_scheduler_metadata: torch.Tensor | None = None
Expand Down Expand Up @@ -127,6 +130,11 @@ def build(
suffix_kv_lens = None
prefix_scheduler_metadata = None

if envs.VLLM_V1_SPANS_ENABLED:
cos_sin_cache = common_attn_metadata.cos_sin_cache
else:
cos_sin_cache = None
Comment thread
tdoublep marked this conversation as resolved.

attn_metadata = TritonAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
Expand All @@ -141,6 +149,7 @@ def build(
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
prefix_scheduler_metadata=prefix_scheduler_metadata,
cos_sin_cache=cos_sin_cache,
)
return attn_metadata

Expand Down Expand Up @@ -343,6 +352,8 @@ def forward(

descale_shape = (cu_seqlens_q.shape[0] - 1, key.shape[1])

cos_sin_cache = attn_metadata.cos_sin_cache

unified_attention(
q=query[:num_actual_tokens],
k=key_cache,
Expand All @@ -363,6 +374,7 @@ def forward(
v_descale=layer._v_scale.expand(descale_shape),
sinks=self.sinks,
output_scale=output_scale,
cos_sin_cache=cos_sin_cache,
)

return output
2 changes: 2 additions & 0 deletions vllm/v1/attention/backends/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ class CommonAttentionMetadata:
block_table_tensor: torch.Tensor
slot_mapping: torch.Tensor

cos_sin_cache: torch.Tensor | None = None

causal: bool = True

# Needed by FastPrefillAttentionBuilder
Expand Down
71 changes: 4 additions & 67 deletions vllm/v1/core/block_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from collections.abc import Iterable, Sequence
from typing import Any

import vllm.envs as envs
from vllm.distributed.kv_events import (
MEDIUM_GPU,
AllBlocksCleared,
Expand Down Expand Up @@ -242,8 +241,6 @@ def cache_full_blocks(
if new_hashes is not None:
new_hashes.append(maybe_convert_block_hash(block_hash))

self._set_block_positions(new_full_blocks, blocks, request)

if self.enable_kv_cache_events:
if num_cached_blocks == 0:
parent_block_hash: ExternalBlockHash | None = None
Expand All @@ -269,58 +266,6 @@ def cache_full_blocks(
)
)

def _set_block_positions(
self,
new_full_blocks: list[KVCacheBlock],
blocks: list[KVCacheBlock],
request: Request,
):
"""Sets the positions of new full blocks in the KV cache.

This function assigns positions to newly filled blocks based
on their order within the provided block list. The position
corresponds to the location embedded in K vectors (if using RoPE)
in the KV cache and is critical for maintaining correct alignment,
especially when prompt positions differ between requests.

Args:
new_full_blocks: List of KVCacheBlock objects that have been newly
filled and require position assignment.
blocks: List of all blocks associated with the current request,
used to determine the order in which positions are assigned.
request: The Request object containing token information for
debugging purposes.

Note:
When VLLM_V1_SPANS_DEBUG is enabled, this function includes
debug logging that prints each block's tokens, to help
debug span-related workflows.
"""
pos = 0
for blk in blocks:
if blk in new_full_blocks:
blk.position = pos
if envs.VLLM_V1_SPANS_DEBUG:
# this prints the tokens assigned to a new block
# in the KV cache
blk_tks = request.all_token_ids[pos : pos + 16]
assert blk.block_hash is not None
bhash = str(blk.block_hash)[:4] if blk.block_hash else None
print(
"[SPANS -> block_pool] assigning to pos",
pos,
"with hash",
bhash,
"block: ",
blk_tks,
)
pos += 16
if envs.VLLM_V1_SPANS_DEBUG:
print(
"[SPANS -> block_pool] assigned block count now ->",
len([b for b in self.blocks if b._block_hash]),
)

def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
"""Get new blocks from the free block pool.

Expand Down Expand Up @@ -413,19 +358,11 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
blocks_list = list(ordered_blocks)
for block in blocks_list:
block.ref_cnt -= 1
# remove duplicates (blocks can now appear twice)
block_ids = set()
blocks_list_filtered = []
for block in blocks_list:
if block.block_id not in block_ids:
blocks_list_filtered.append(block)
block_ids.add(block.block_id)
# Remove duplicates while preserving order
dedup_bl = list({block.block_id: block for
block in blocks_list}.values())
self.free_block_queue.append_n(
[
block
for block in blocks_list_filtered
if block.ref_cnt == 0 and not block.is_null
]
[block for block in dedup_bl if block.ref_cnt == 0 and not block.is_null]
)

def reset_prefix_cache(self) -> bool:
Expand Down
Loading