diff --git a/README.md b/README.md
index 4e03df758c..0764bc9c6b 100644
--- a/README.md
+++ b/README.md
@@ -16,6 +16,12 @@ Easy, fast, and cheap LLM serving for everyone
 
 ---
 
+## What is the purpose of this fork?
+
+This is a fork of vLLM which we are using to develop support for *span semantics*.
+
+---
+
 *Latest News* 🔥
 
 - [2025/08] We hosted [vLLM Shenzhen Meetup](https://mp.weixin.qq.com/s/k8ZBO1u2_2odgiKWH_GVTQ) focusing on the ecosystem around vLLM! Please find the meetup slides [here](https://drive.google.com/drive/folders/1Ua2SVKVSu-wp5vou_6ElraDt2bnKhiEA).
diff --git a/examples/offline_inference/spans/spans.py b/examples/offline_inference/spans/spans.py
new file mode 100644
index 0000000000..ebe42f7ba7
--- /dev/null
+++ b/examples/offline_inference/spans/spans.py
@@ -0,0 +1,200 @@
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+import os
+import time
+
+# to ensure deterministic behaviour
+os.environ["TOKENIZERS_PARALLELISM"] = "False"
+
+# standard imports
+from vllm import LLM, SamplingParams
+from vllm.inputs import TokensPrompt
+
+
+# helper functions
+def pad(toklist, padtok):
+    """Pad ``toklist`` to a multiple of 16 tokens, keeping its last token last."""
+    return toklist[:-1] + [padtok] * ((16 - len(toklist)) % 16) + toklist[-1:]
+
+
+def avg(list_of_numbers):
+    """Return the arithmetic mean of the list, or 0 for an empty list."""
+    return sum(list_of_numbers) / max(len(list_of_numbers), 1)
+
+
+def wrap(prompt):
+    """Wrap a token-id list (or a list of such lists) in ``TokensPrompt``."""
+    if prompt and isinstance(prompt[0], list):
+        return [TokensPrompt(prompt_token_ids=p) for p in prompt]
+    return TokensPrompt(prompt_token_ids=prompt)
+
+
+def initialize_vllm(
+    model, temp=0.6, logprobs=None, max_toks=32768, max_generated_toks=1
+):
+    """Build the LLM, its tokenizer helper and the two sampling configs.
+
+    Note: ``max_toks`` is accepted but not used by this function.
+    """
+    # boot up vLLM
+    samp_params_preload = SamplingParams(temperature=temp, max_tokens=1)
+    samp_params_generate = SamplingParams(
+        temperature=temp, max_tokens=max_generated_toks, logprobs=logprobs
+    )
+    llm = LLM(
+        model=model,
+        gpu_memory_utilization=0.9,
+        enforce_eager=True,  # <- so it boots faster
+        block_size=16,
+    )
+    tok = llm.get_tokenizer()
+    tok_fun = lambda x: tok.convert_tokens_to_ids(tok.tokenize(x))
+    return samp_params_preload, samp_params_generate, tok_fun, llm
+
+
+def main():
+    """Preload two documents as spans, then time two generations that reuse them."""
+    model_names = [
+        "ldsjmdy/Tulu3-Block-FT",  # <- finetuned to handle block-attention
+        "ldsjmdy/Tulu3-RAG",  # <- baseline
+    ]
+    model_name = model_names[0]
+
+    # tokens that need to be set to perform block-attention
+    PAD_TOK = 27  # <- "<"
+    SPAN_TOK_PLUS = 10  # <- "+"
+    SPAN_TOK_CROSS = 31  # <- "@"
+
+    # vLLM-specific env vars
+
+    # enables block attention
+    # -> when this line is not commented, we expect a speedup
+    #    in the execution of the last two .generate calls
+    os.environ["VLLM_V1_SPANS_ENABLED"] = "True"
+
+    # the token that tells vLLM "this is the beginning of a span"
+    os.environ["VLLM_V1_SPANS_TOKEN_PLUS"] = str(SPAN_TOK_PLUS)
+
+    # token that tells vLLM:
+    # "from here on, recompute KV vectors if any previous tokens differ"
+    os.environ["VLLM_V1_SPANS_TOKEN_CROSS"] = str(SPAN_TOK_CROSS)
+
+    # will print every step of the span process if set to true
+    os.environ["VLLM_V1_SPANS_DEBUG"] = "True"
+
+    # will disable the adjustment of positional encodings when a KV cache
+    # block is loaded to a different position than it was stored
+    # -> when this line is not commented,
+    #    spans overlap in their positional encodings
+    os.environ["VLLM_V1_SPANS_DISABLE_REPOSITION"] = "True"
+
+    # general env vars
+
+    # now we instantiate the model
+    samp_params_preload, samp_params_generate, tok, llm = initialize_vllm(
+        model_name, max_generated_toks=128, max_toks=10_000, temp=0.0
+    )
+
+    # components of the prompt template
+    prefix = pad(
+        tok(
+            "<|system|>\nYou are an intelligent AI assistant. "
+            "Please answer questions based on the user's instructions. "
+            "Below are some reference documents that may help you in "
+            "answering the user's question."
+        ),
+        PAD_TOK,
+    )
+    midfx = [SPAN_TOK_CROSS] + tok(
+        "<|user|>\nPlease write a high-quality answer for the "
+        "given question using only the provided search documents "
+        "(some of which might be irrelevant).\nQuestion: "
+    )
+    postfx = tok("""\n<|assistant|>\n""")
+
+    print("---->", postfx)
+
+    # task-specific documents
+    doc_a = pad(
+        [SPAN_TOK_PLUS]
+        + tok(
+            "[0] The Template-Assisted "
+            "Selective Epitaxy (TASE) method, developed at "
+            "IBM Research Europe – Zurich, permits to "
+            "create a homogeneous integration route for "
+            "various semiconductor materials which is "
+            "compatible with the CMOS process."
+        ),
+        PAD_TOK,
+    )
+
+    doc_b = pad(
+        [SPAN_TOK_PLUS]
+        + tok(
+            "[1] The dominant sequence transduction "
+            "models are based on complex recurrent or "
+            "convolutional neural networks in an encoder-decoder "
+            "configuration. "
+        ),
+        PAD_TOK,
+    )
+
+    # # alt-docs (purely to check performance on longer documents)
+    """
+    a_toks = tok("Sequence Transduction Models")
+    b_toks = tok("Template-Assisted Selective Epitaxy")
+    doc_a = pad(
+        [SPAN_TOK_PLUS] +
+        [a_toks[idx % len(a_toks)] for idx in range(10_000)],
+        PAD_TOK,
+    )
+    doc_b = pad(
+        [SPAN_TOK_PLUS] +
+        [b_toks[idx % len(b_toks)] for idx in range(10_000)],
+        PAD_TOK,
+    )
+    """
+
+    # user query
+    query = (
+        midfx
+        + tok(
+            "Tell me which one concerns deep learning. "
+            "Indicate your answer with a number in brackets."
+        )
+        + postfx
+    )
+
+    # preload documents
+    ts_pre = time.time()
+    llm.generate(
+        [wrap(doc_a), wrap(doc_b), wrap(prefix)], sampling_params=samp_params_preload
+    )
+    te_pre = time.time() - ts_pre
+
+    ts_gen = time.time()
+
+    # this now will load prefix, doc_a, doc_b,
+    # from the KV cache regardless of the order
+    model_response_1 = llm.generate(
+        wrap(prefix + doc_a + doc_b + query),
+        sampling_params=samp_params_generate,
+        use_tqdm=False,
+    )
+
+    # this should also run faster:
+    model_response_2 = llm.generate(
+        wrap(prefix + doc_b + doc_a + query),
+        sampling_params=samp_params_generate,
+        use_tqdm=False,
+    )
+
+    te_gen = time.time() - ts_gen
+
+    print(f"doc preload time / 2x generate : {te_pre:.4f} / {te_gen:.4f} (s)")
+    print("model output 1 was:", model_response_1[0].outputs[0].text)
+    print("model output 2 was:", model_response_2[0].outputs[0].text)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/vllm/envs.py b/vllm/envs.py
index 8d199da45b..5d88c1d39f 100755
--- a/vllm/envs.py
+++ b/vllm/envs.py
@@ -172,6 +172,12 @@
     VLLM_GPT_OSS_HARMONY_SYSTEM_INSTRUCTIONS: bool = False
     VLLM_CUSTOM_SCOPES_FOR_PROFILING: bool = False
     VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES: bool = True
+    # spans vars
+    VLLM_V1_SPANS_ENABLED: bool = False
+    VLLM_V1_SPANS_DEBUG: bool = False
+    VLLM_V1_SPANS_TOKEN_PLUS: int = -1
+    VLLM_V1_SPANS_TOKEN_CROSS: int = -1
+    VLLM_V1_SPANS_DISABLE_REPOSITION: bool = False
 
 
 def get_default_cache_root():
@@ -1221,6 +1227,31 @@ def get_vllm_port() -> Optional[int]:
     # raw bytes. Defaults to True for backward compatibility.
     "VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES":
     lambda: bool(int(os.getenv("VLLM_KV_EVENTS_USE_INT_BLOCK_HASHES", "1"))),
+
+    # whether to enable block-attention (span detection, fan-in, repositioning)
+    "VLLM_V1_SPANS_ENABLED":
+    lambda: os.getenv("VLLM_V1_SPANS_ENABLED", "0").lower() in ("1", "true"),
+
+    # whether to print details pertaining to the block-attention
+    # implementation
+    "VLLM_V1_SPANS_DEBUG":
+    lambda: os.getenv("VLLM_V1_SPANS_DEBUG", "0").lower() in ("1", "true"),
+
+    # for block-attention, the token that will be used in order to
+    # indicate the beginning of a span (needed for it to work)
+    "VLLM_V1_SPANS_TOKEN_PLUS":
+    lambda: int(os.environ.get("VLLM_V1_SPANS_TOKEN_PLUS", "-1")),
+
+    # for block-attention, a token that signals the beginning of a
+    # span which needs to depend on all previous tokens
+    "VLLM_V1_SPANS_TOKEN_CROSS":
+    lambda: int(os.environ.get("VLLM_V1_SPANS_TOKEN_CROSS", "-1")),
+
+    # for block-attention, detected spans will be loaded but not repositioned
+    "VLLM_V1_SPANS_DISABLE_REPOSITION":
+    lambda: os.getenv("VLLM_V1_SPANS_DISABLE_REPOSITION", "0").lower() in
+    ("1", "true"),
+
 }
 
 # --8<-- [end:env-vars-definition]
diff --git a/vllm/model_executor/layers/rotary_embedding/base.py b/vllm/model_executor/layers/rotary_embedding/base.py
index be25e90abf..69eef7ba9e 100644
--- a/vllm/model_executor/layers/rotary_embedding/base.py
+++ b/vllm/model_executor/layers/rotary_embedding/base.py
@@ -63,6 +63,7 @@ def forward_native(
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
+        invert_rotation_angle: bool = False  # <- to unrope kv's
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """A PyTorch-native implementation of forward()."""
         if offsets is not None:
@@ -71,6 +72,8 @@ def forward_native(
         num_tokens = positions.shape[0]
         cos_sin = self.cos_sin_cache.index_select(0, positions)
         cos, sin = cos_sin.chunk(2, dim=-1)
+        if invert_rotation_angle:
+            sin = -sin
 
         query_shape = query.shape
         query = query.view(num_tokens, -1, self.head_size)
diff --git a/vllm/model_executor/layers/rotary_embedding/mrope.py b/vllm/model_executor/layers/rotary_embedding/mrope.py
index 0ab4bc5375..4601f147ea 100644
--- a/vllm/model_executor/layers/rotary_embedding/mrope.py
+++ b/vllm/model_executor/layers/rotary_embedding/mrope.py
@@ -230,6 +230,7 @@ def forward_native(
         query: torch.Tensor,
         key: Optional[torch.Tensor] = None,
         offsets: Optional[torch.Tensor] = None,
+        invert_rotation_angle: bool = False,  # NOTE(review): currently ignored
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         """PyTorch-native implementation equivalent to forward().
 
diff --git a/vllm/v1/core/block_pool.py b/vllm/v1/core/block_pool.py
index d1e1c1c8d0..9069a364db 100644
--- a/vllm/v1/core/block_pool.py
+++ b/vllm/v1/core/block_pool.py
@@ -4,6 +4,7 @@
 from collections.abc import Iterable
 from typing import Optional
 
+import vllm.envs as envs
 from vllm.distributed.kv_events import (MEDIUM_GPU, AllBlocksCleared,
                                         BlockRemoved, BlockStored,
                                         KVCacheEvent)
@@ -145,6 +146,8 @@ def cache_full_blocks(
             if new_hashes is not None:
                 new_hashes.append(maybe_convert_block_hash(block_hash))
 
+        self._set_block_positions(new_full_blocks, blocks, request)
+
         if self.enable_kv_cache_events:
             if num_cached_blocks == 0:
                 parent_block_hash: Optional[ExternalBlockHash] = None
@@ -167,6 +170,47 @@ def cache_full_blocks(
                     medium=MEDIUM_GPU,
                 ))
 
+    def _set_block_positions(self, new_full_blocks: list[KVCacheBlock],
+                             blocks: list[KVCacheBlock], request: Request):
+        """Sets the positions of new full blocks in the KV cache.
+
+        This function assigns positions to newly filled blocks based
+        on their order within the provided block list. The position
+        corresponds to the location embedded in K vectors (if using RoPE)
+        in the KV cache and is critical for maintaining correct alignment,
+        especially when prompt positions differ between requests.
+
+        Args:
+            new_full_blocks: List of KVCacheBlock objects that have been newly
+                filled and require position assignment.
+            blocks: List of all blocks associated with the current request,
+                used to determine the order in which positions are assigned.
+            request: The Request object containing token information for
+                debugging purposes.
+
+        Note:
+            When VLLM_V1_SPANS_DEBUG is enabled, this function includes
+            debug logging that prints each block's tokens, to help
+            debug span-related workflows.
+        """
+        pos = 0
+        for blk in blocks:
+            if blk in new_full_blocks:
+                blk.position = pos
+                if envs.VLLM_V1_SPANS_DEBUG:
+                    # this prints the tokens assigned to a new block
+                    # in the KV cache
+                    blk_tks = request.all_token_ids[pos:pos + 16]
+                    assert blk.block_hash is not None
+                    bhash = str(abs(blk.block_hash.block_hash.hash_value)
+                                )[:4] if blk.block_hash.block_hash else None
+                    print('[SPANS -> block_pool] assigning to pos', pos,
+                          'with hash', bhash, 'block: ', blk_tks)
+            pos += 16  # NOTE(review): assumes block_size == 16
+        if envs.VLLM_V1_SPANS_DEBUG:
+            print('[SPANS -> block_pool] assigned block count now ->',
+                  len([b for b in self.blocks if b._block_hash]))
+
     def get_new_blocks(self, num_blocks: int) -> list[KVCacheBlock]:
         """Get new blocks from the free block pool.
 
@@ -261,8 +305,15 @@ def free_blocks(self, ordered_blocks: Iterable[KVCacheBlock]) -> None:
         blocks_list = list(ordered_blocks)
         for block in blocks_list:
             block.ref_cnt -= 1
+        # remove duplicates (blocks can now appear twice)
+        block_ids = set()
+        blocks_list_filtered = []
+        for block in blocks_list:
+            if block.block_id not in block_ids:
+                blocks_list_filtered.append(block)
+                block_ids.add(block.block_id)
         self.free_block_queue.append_n([
-            block for block in blocks_list
+            block for block in blocks_list_filtered
             if block.ref_cnt == 0 and not block.is_null
         ])
 
diff --git a/vllm/v1/core/kv_cache_manager.py b/vllm/v1/core/kv_cache_manager.py
index 3a0fbb5e5c..a6628cfc55 100644
--- a/vllm/v1/core/kv_cache_manager.py
+++ b/vllm/v1/core/kv_cache_manager.py
@@ -4,6 +4,7 @@
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Literal, Optional, overload
 
+import vllm.envs as envs
 from vllm.distributed.kv_events import KVCacheEvent
 from vllm.logger import init_logger
 from vllm.v1.core.kv_cache_coordinator import get_kv_cache_coordinator
@@ -15,6 +16,16 @@
 logger = init_logger(__name__)
 
 
+@dataclass
+class BlockRepositionRequest:
+    """A cached block whose stored position differs from its prompt position."""
+    block_id: int
+    # position recorded on the cached block (KVCacheBlock.position)
+    kvc_pos: int
+    # position of the block's first token in the current prompt
+    prompt_pos: int
+
+
 @dataclass
 class KVCacheBlocks:
     """
@@ -23,6 +34,8 @@ class KVCacheBlocks:
     structure from the Scheduler.
     """
     blocks: tuple[list[KVCacheBlock], ...]
+    blocks_to_reposition: list[BlockRepositionRequest] = field(
+        default_factory=list)
     """
     blocks[i][j] refers to the i-th kv_cache_group and the j-th block of tokens.
     We don't use block of tokens as the outer dimension because it assumes all
@@ -35,7 +48,8 @@ def __add__(self, other: "KVCacheBlocks") -> "KVCacheBlocks":
         """Adds two KVCacheBlocks instances."""
         return KVCacheBlocks(
             tuple(blk1 + blk2
-                  for blk1, blk2 in zip(self.blocks, other.blocks)))
+                  for blk1, blk2 in zip(self.blocks, other.blocks)),
+            self.blocks_to_reposition + other.blocks_to_reposition)
 
     @overload
     def get_block_ids(
@@ -78,7 +92,7 @@ def get_unhashed_block_ids(self) -> list[int]:
 
     def new_empty(self) -> "KVCacheBlocks":
         """Creates a new KVCacheBlocks instance with no blocks."""
-        return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))))
+        return KVCacheBlocks(tuple([] for _ in range(len(self.blocks))), [])
 
 
 class KVCacheManager:
@@ -180,6 +194,57 @@ def get_computed_blocks(self,
         computed_blocks, num_new_computed_tokens = (
             self.coordinator.find_longest_cache_hit(request.block_hashes,
                                                     max_cache_hit_length))
+        if envs.VLLM_V1_SPANS_DEBUG:
+            print(
+                "[SPANS -> kv_cache_manager] here's the blocks hashed in " \
+                "this request:",
+                [str(abs(b.hash_value))[:4] for b in request.block_hashes])
+            kvcache_contents = [
+                str(abs(b.block_hash.block_hash.hash_value))[:4]
+                if b.block_hash else None for b in self.block_pool.blocks
+                if b._block_hash
+            ]
+            if len(kvcache_contents) > 32:
+                kvcache_contents = kvcache_contents[:32] + [
+                    '... (too long to print it all)'
+                ]
+            print(
+                "[SPANS -> kv_cache_manager] here's the contents of the " \
+                "kv cache:",
+                kvcache_contents)
+            print(
+                "[SPANS -> kv_cache_manager] here's the blocks " \
+                "that hit the cache:",
+                [
+                    str(abs(b.block_hash.block_hash.hash_value))[:4]
+                    if b.block_hash else None for b in computed_blocks[0]
+                ])
+
+        blocks_to_reposition = []
+        if envs.VLLM_V1_SPANS_ENABLED:
+            # Spans does not yet support hybrid models
+            assert len(computed_blocks) == 1
+            for i, b in enumerate(computed_blocks[0]):
+                prompt_pos = i * 16  # NOTE(review): assumes block_size == 16
+                kvc_pos = b.position
+                if envs.VLLM_V1_SPANS_DEBUG:
+                    print(
+                        f"[SPANS -> kv_cache_manager] checking block " \
+                        f"{b.block_id} with prompt pos {prompt_pos} " \
+                        f"and kv pos {kvc_pos}"
+                    )
+                assert isinstance(kvc_pos, int)
+                if kvc_pos != prompt_pos:
+                    if envs.VLLM_V1_SPANS_DEBUG:
+                        print(
+                            f"[SPANS -> kv_cache_manager] from pos: {kvc_pos} "\
+                            f"to prompt pos: {prompt_pos} repositioning needed"
+                        )
+
+                    blocks_to_reposition.append(
+                        BlockRepositionRequest(b.block_id, kvc_pos,
+                                               prompt_pos))
+                    b.position = int(prompt_pos)
 
         if self.log_stats:
             assert self.prefix_cache_stats is not None
@@ -187,7 +252,8 @@ def get_computed_blocks(self,
             self.prefix_cache_stats.queries += request.num_tokens
             self.prefix_cache_stats.hits += num_new_computed_tokens
 
-        return KVCacheBlocks(computed_blocks), num_new_computed_tokens
+        return KVCacheBlocks(computed_blocks, blocks_to_reposition),\
+            num_new_computed_tokens
 
     def allocate_slots(
         self,
@@ -290,7 +356,7 @@ def allocate_slots(
         # P/D: delay caching blocks if we have to recv from
         # remote. Update state for locally cached blocks.
         if not self.enable_caching or delay_cache_blocks:
-            return KVCacheBlocks(new_blocks)
+            return KVCacheBlocks(new_blocks, [])
 
         # NOTE(woosuk): We want to commit (cache) up to num_computed_tokens +
         # num_new_tokens, but must exclude "non-committable" tokens (e.g.,
@@ -300,7 +366,7 @@ def allocate_slots(
                                   request.num_tokens)
         self.coordinator.cache_blocks(request, num_tokens_to_cache)
 
-        return KVCacheBlocks(new_blocks)
+        return KVCacheBlocks(new_blocks, [])
 
     def free(self, request: Request) -> None:
         """Free the blocks allocated for the request.
@@ -381,7 +447,7 @@ def take_events(self) -> list[KVCacheEvent]:
 
     def get_blocks(self, request_id: str) -> KVCacheBlocks:
         """Get the blocks of a request."""
-        return KVCacheBlocks(self.coordinator.get_blocks(request_id))
+        return KVCacheBlocks(self.coordinator.get_blocks(request_id), [])
 
     def get_block_ids(self, request_id: str) -> tuple[list[int], ...]:
         """Get the block ids of a request."""
@@ -394,5 +460,5 @@ def cache_blocks(self, request: Request, num_computed_tokens: int) -> None:
 
     def create_empty_block_list(self) -> KVCacheBlocks:
         """Creates a new KVCacheBlocks instance with no blocks."""
-        return KVCacheBlocks(tuple([]
-                                   for _ in range(self.num_kv_cache_groups)))
+        return KVCacheBlocks(
+            tuple([] for _ in range(self.num_kv_cache_groups)), [])
diff --git a/vllm/v1/core/kv_cache_utils.py b/vllm/v1/core/kv_cache_utils.py
index 2c0eac3ddd..077b09f4ab 100644
--- a/vllm/v1/core/kv_cache_utils.py
+++ b/vllm/v1/core/kv_cache_utils.py
@@ -162,6 +162,8 @@ class KVCacheBlock:
     block_id: int
     # Reference count.
     ref_cnt: int = 0
+    # Position (corresponds to positional encodings position)
+    position: Optional[int] = None
     # The hash key (block hash + group id) of the block, only available
     # when the block is full and cached.
_block_hash: Optional[BlockHashWithGroupId] = None @@ -559,12 +561,38 @@ def hash_block_tokens( if not parent_block_hash: parent_block_hash = NONE_HASH + if envs.VLLM_V1_SPANS_ENABLED: + if envs.VLLM_V1_SPANS_TOKEN_PLUS == -1: + raise Exception( + '[SPANS -> kv_cache_utils]: span separator token undefined!') + # if a block starts with the span separator token, then its hash + # should be independent of previous tokens + firstok = curr_block_token_ids[0] + if firstok == envs.VLLM_V1_SPANS_TOKEN_PLUS: + if envs.VLLM_V1_SPANS_DEBUG: + print(f'[SPANS -> kv_cache_utils] detected span separator " \ + "token {envs.VLLM_V1_SPANS_TOKEN_PLUS} -> enable fan-in' + ) + parent_block_hash = NONE_HASH + curr_block_token_ids_tuple = tuple(curr_block_token_ids) return BlockHash( hash_function( (parent_block_hash, curr_block_token_ids_tuple, extra_keys))) +def recompute_token_handler( + block_first_token: int, tokens_up_to_block: list[int], + extra_keys: Union[tuple[Any, ...], + None]) -> Union[tuple[Any, ...], None]: + if envs.VLLM_V1_SPANS_ENABLED and \ + block_first_token == envs.VLLM_V1_SPANS_TOKEN_CROSS: + tok_tuple = tuple(tokens_up_to_block) + extra_keys = (*extra_keys, tok_tuple) if extra_keys \ + else tok_tuple + return extra_keys + + def get_request_block_hasher( block_size: int, caching_hash_fn: Callable[[Any], bytes], @@ -600,6 +628,8 @@ def request_block_hasher(request: Request) -> list[BlockHash]: # Compute the hash of the current block block_tokens = request.all_token_ids[start_token_idx:end_token_idx] + extra_keys = recompute_token_handler( + block_tokens[0], block_tokens[:start_token_idx], extra_keys) block_hash = hash_block_tokens(caching_hash_fn, prev_block_hash_value, block_tokens, extra_keys) diff --git a/vllm/v1/core/sched/output.py b/vllm/v1/core/sched/output.py index b5cd6c5c8a..eacfaaa1ff 100644 --- a/vllm/v1/core/sched/output.py +++ b/vllm/v1/core/sched/output.py @@ -16,6 +16,7 @@ from vllm.multimodal.inputs import MultiModalKwargsItem, PlaceholderRange from 
vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams + from vllm.v1.core.kv_cache_manager import BlockRepositionRequest from vllm.v1.request import Request @@ -153,5 +154,8 @@ class SchedulerOutput: # the bitmask for the whole batch grammar_bitmask: Optional[npt.NDArray[np.int32]] + # for KV cache repositioning (as part of Block-Attention implementation) + blocks_to_reposition: list[BlockRepositionRequest] + # KV Cache Connector metadata. kv_connector_metadata: Optional[KVConnectorMetadata] = None diff --git a/vllm/v1/core/sched/scheduler.py b/vllm/v1/core/sched/scheduler.py index 2d40e96632..dbcb7ed39f 100644 --- a/vllm/v1/core/sched/scheduler.py +++ b/vllm/v1/core/sched/scheduler.py @@ -9,6 +9,7 @@ from collections.abc import Iterable from typing import Any, Optional, Union +import vllm.envs as envs from vllm.config import VllmConfig from vllm.distributed.kv_events import EventPublisherFactory, KVEventBatch from vllm.distributed.kv_transfer.kv_connector.factory import ( @@ -19,7 +20,8 @@ from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalRegistry from vllm.v1.core.encoder_cache_manager import (EncoderCacheManager, compute_encoder_budget) -from vllm.v1.core.kv_cache_manager import KVCacheBlocks, KVCacheManager +from vllm.v1.core.kv_cache_manager import (BlockRepositionRequest, + KVCacheBlocks, KVCacheManager) from vllm.v1.core.sched.interface import SchedulerInterface from vllm.v1.core.sched.output import (CachedRequestData, NewRequestData, SchedulerOutput) @@ -330,6 +332,7 @@ def schedule(self) -> SchedulerOutput: skipped_waiting_requests = create_request_queue(self.policy) # Next, schedule the WAITING requests. 
+        blocks_to_reposition: list[BlockRepositionRequest] = []
         if not preempted_reqs:
             while self.waiting and token_budget > 0:
                 if len(self.running) == self.max_num_running_reqs:
@@ -381,6 +384,12 @@ def schedule(self) -> SchedulerOutput:
                         self.kv_cache_manager.get_computed_blocks(
                             request)
 
+                    # handle repositioning requests
+                    if envs.VLLM_V1_SPANS_ENABLED and \
+                        len(new_computed_blocks.blocks_to_reposition) > 0:
+                        blocks_to_reposition.extend(
+                            new_computed_blocks.blocks_to_reposition)
+
                     # Get externally-cached tokens if using a KVConnector.
                     if self.connector is not None:
                         num_external_computed_tokens, load_kv_async = (
@@ -589,6 +598,7 @@ def schedule(self) -> SchedulerOutput:
             get_freed_mm_hashes(),
             structured_output_request_ids=structured_output_request_ids,
             grammar_bitmask=grammar_bitmask,
+            blocks_to_reposition=blocks_to_reposition,
         )
 
         # NOTE(Kuntai): this function is designed for multiple purposes:
diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py
index 549c5dd2bb..adb79f1e62 100644
--- a/vllm/v1/worker/gpu_model_runner.py
+++ b/vllm/v1/worker/gpu_model_runner.py
@@ -44,6 +44,7 @@
                                                    supports_transcription)
 from vllm.model_executor.models.interfaces_base import (
     VllmModelForPooling, is_pooling_model, is_text_generation_model)
+from vllm.model_executor.models.utils import PPMissingLayer
 from vllm.multimodal import MULTIMODAL_REGISTRY
 from vllm.multimodal.inputs import (BatchedTensorInputs, MultiModalKwargsItem,
                                     PlaceholderRange)
@@ -1587,6 +1588,115 @@ def _pool(
             kv_connector_output=kv_connector_output,
         )
 
+    def _perform_repositioning(self,
+                               scheduler_output: "SchedulerOutput") -> None:
+        """
+        Repositions KV cache blocks based on the scheduler's instructions.
+
+        This method handles the repositioning of attention block
+        vectors in the KV cache when their positions in the KV cache
+        and in the prompt differ. It applies rotary embedding
+        transformations to adjust the positions.
+
+        Args:
+            scheduler_output: The output from the scheduler containing blocks
+                to reposition.
+        """
+        blocks_to_reposition = scheduler_output.blocks_to_reposition
+        if envs.VLLM_V1_SPANS_DEBUG:
+            ts_repo = time.time()
+            repo_count = len(blocks_to_reposition)
+        if len(blocks_to_reposition) < 600:
+            self._repositionings_handler(blocks_to_reposition)
+        else:
+            bs = 400
+            for i in range(0, len(blocks_to_reposition), bs):
+                repo_batch = blocks_to_reposition[i:i+bs]
+                self._repositionings_handler(repo_batch)
+        if envs.VLLM_V1_SPANS_DEBUG and repo_count > 0:
+            torch.cuda.synchronize()
+            t_repo = time.time() - ts_repo
+            print(f'[SPANS -> gpu_model_runner] repositioning' \
+                  f' speed: {repo_count/t_repo:.2f} (blocks/s)'\
+                  f' (total {repo_count})')
+
+    @torch.inference_mode()
+    def _repositionings_handler(self, blocks_to_reposition):
+        """Re-rotate the cached K vectors of the given blocks in place.
+
+        For each layer, the index-0 (K, per the shape note below) entries
+        of the listed blocks are un-rotated from ``kvc_pos`` and rotated
+        again to ``prompt_pos`` with the model's rotary embedding, then
+        written back to ``self.kv_caches``. Nothing is done for an empty
+        list or when VLLM_V1_SPANS_DISABLE_REPOSITION is set.
+
+        Args:
+            blocks_to_reposition: items exposing ``block_id``,
+                ``kvc_pos`` and ``prompt_pos``.
+        """
+        num_repos = len(blocks_to_reposition)
+        if envs.VLLM_V1_SPANS_DEBUG and num_repos > 0:
+            print(
+                f'[SPANS -> gpu_model_runner] ' \
+                f'reposition block count: {num_repos}'
+            )
+        if not envs.VLLM_V1_SPANS_DISABLE_REPOSITION and num_repos > 0:
+            kvc_positions = torch.tensor(
+                [d.kvc_pos for d in blocks_to_reposition],
+                dtype=torch.long,
+                device=self.kv_caches[0].device).unsqueeze(-1)
+            prt_positions = torch.tensor(
+                [d.prompt_pos for d in blocks_to_reposition],
+                dtype=torch.long,
+                device=self.kv_caches[0].device).unsqueeze(-1)
+            block_ids = torch.tensor(
+                [d.block_id for d in blocks_to_reposition],
+                dtype=torch.long,
+                device=self.kv_caches[0].device)
+
+            # (self.kv_caches shape):
+            # [nlay, kv, maxblocks, blocksize, headcount, headsize]
+            concerned_vectors = [
+                x[0, block_ids, :, :, :] for x in self.kv_caches
+            ]  # -> [nlay, blockids, blocksize, headcount, headsize]
+            _, bsize, _, _ = concerned_vectors[0].shape
+
+            template_tensor = torch.arange(
+                bsize, dtype=torch.long,
+                device=self.kv_caches[0].device).unsqueeze(0)
+            pos_depos = kvc_positions + template_tensor
+            pos_repos = prt_positions + template_tensor
+
+            # precision highly affects the outputs
+            PRECISION = torch.float32
+            DEF_PRECISION = self.kv_caches[0].dtype
+
+            # do the rotation
+            # note: PPMissingLayer is for pipeline parallel support
+            if not hasattr(self, 'rotate'):
+                if not isinstance(self.model.model.layers[0], PPMissingLayer):
+                    self.rotate = self.model.model.layers[
+                        0].self_attn.rotary_emb
+                else:
+                    for lay in self.model.model.layers:
+                        if not isinstance(lay, PPMissingLayer):
+                            self.rotate = lay.self_attn.rotary_emb
+                            break
+            assert pos_depos.shape[0] == concerned_vectors[0].shape[0]
+
+            # Rotate one layer at a time. forward_native views its input as
+            # (num_tokens, -1, head_size), so stacking all layers on a leading
+            # axis would pair K vectors with the wrong positions.
+            for i, k_vectors in enumerate(concerned_vectors):
+                k_vectors_tmp, _ = self.rotate.forward_native(
+                    pos_depos,
+                    k_vectors.to(PRECISION),
+                    invert_rotation_angle=True)
+                k_vectors_tmp, _ = self.rotate.forward_native(
+                    pos_repos, k_vectors_tmp)
+                self.kv_caches[i][0, block_ids, ...] = \
+                    k_vectors_tmp.to(DEF_PRECISION)
+
     def _preprocess(
         self,
         scheduler_output: "SchedulerOutput",
@@ -1850,6 +1960,10 @@ def execute_model(
         intermediate_tensors: Optional[IntermediateTensors] = None,
     ) -> Union[ModelRunnerOutput, AsyncModelRunnerOutput, IntermediateTensors]:
         with record_function_or_nullcontext("Preprocess"):
+
+            # handle repositioning requests
+            self._perform_repositioning(scheduler_output)
+
             self._update_states(scheduler_output)
             if not scheduler_output.total_num_scheduled_tokens:
                 if not has_kv_transfer_group():