diff --git a/tests/v1/kv_connector/unit/test_nixl_connector.py b/tests/v1/kv_connector/unit/test_nixl_connector.py index 1975d2226073..ffdada5f1b8f 100644 --- a/tests/v1/kv_connector/unit/test_nixl_connector.py +++ b/tests/v1/kv_connector/unit/test_nixl_connector.py @@ -2326,3 +2326,249 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario) remote_tp_size=1, expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID, ) + + +# --------------------------------------------------------------------------- +# FP8 KV cache scale transfer tests +# --------------------------------------------------------------------------- + + +class _FakeAttentionLayer: + """Minimal fake attention layer with FP8 scale attributes.""" + + def __init__(self, k_scale: float = 1.0, v_scale: float = 1.0): + import torch + + self._k_scale = torch.tensor(k_scale, dtype=torch.float32) + self._v_scale = torch.tensor(v_scale, dtype=torch.float32) + self._k_scale_float = k_scale + self._v_scale_float = v_scale + + +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, +) +def test_collect_fp8_kv_scales_non_fp8(default_vllm_config, dist_init): + """_collect_fp8_kv_scales returns empty lists for non-fp8 cache dtype.""" + vllm_config = create_vllm_config(cache_dtype="auto") + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + worker = connector.connector_worker + + k_scales, v_scales = worker._collect_fp8_kv_scales() + + assert k_scales == [] + assert v_scales == [] + + +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, +) +def test_collect_fp8_kv_scales_fp8(default_vllm_config, dist_init): + """_collect_fp8_kv_scales returns correct per-layer scales for fp8.""" + vllm_config = create_vllm_config(cache_dtype="fp8") + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + worker = connector.connector_worker + + # Populate static_forward_context with 
fake attention layers. + layer_a = _FakeAttentionLayer(k_scale=0.5, v_scale=0.25) + layer_b = _FakeAttentionLayer(k_scale=1.5, v_scale=2.0) + # Use sorted-order names to match expected collection order. + worker.vllm_config.compilation_config.static_forward_context = { + "model.layers.0.self_attn": layer_a, + "model.layers.1.self_attn": layer_b, + # Non-attention layer that should be ignored. + "model.layers.0.mlp": object(), + } + + k_scales, v_scales = worker._collect_fp8_kv_scales() + + assert k_scales == [0.5, 1.5] + assert v_scales == [0.25, 2.0] + + +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, +) +def test_apply_fp8_kv_scales(default_vllm_config, dist_init): + """_apply_fp8_kv_scales updates local attention layers with remote scales.""" + vllm_config = create_vllm_config(cache_dtype="fp8") + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + worker = connector.connector_worker + + # Set up local attention layers with default scales. + layer_a = _FakeAttentionLayer(k_scale=1.0, v_scale=1.0) + layer_b = _FakeAttentionLayer(k_scale=1.0, v_scale=1.0) + worker.vllm_config.compilation_config.static_forward_context = { + "model.layers.0.self_attn": layer_a, + "model.layers.1.self_attn": layer_b, + } + + # Simulate metadata received from remote P-side with calibrated scales. + remote_metadata = NixlAgentMetadata( + engine_id="remote-engine", + agent_metadata=b"", + kv_caches_base_addr=[], + device_id=0, + num_blocks=1, + block_lens=[], + kv_cache_layout="HND", + block_size=16, + fp8_k_scales=[0.5, 1.5], + fp8_v_scales=[0.25, 2.0], + ) + + worker._apply_fp8_kv_scales(remote_metadata) + + # Verify scales were updated. 
+ import torch + + assert torch.isclose(layer_a._k_scale, torch.tensor(0.5)) + assert torch.isclose(layer_a._v_scale, torch.tensor(0.25)) + assert layer_a._k_scale_float == 0.5 + assert layer_a._v_scale_float == 0.25 + + assert torch.isclose(layer_b._k_scale, torch.tensor(1.5)) + assert torch.isclose(layer_b._v_scale, torch.tensor(2.0)) + assert layer_b._k_scale_float == 1.5 + assert layer_b._v_scale_float == 2.0 + + +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, +) +def test_apply_fp8_kv_scales_no_op_for_empty(default_vllm_config, dist_init): + """_apply_fp8_kv_scales is a no-op when remote metadata has empty scales.""" + vllm_config = create_vllm_config(cache_dtype="fp8") + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + worker = connector.connector_worker + + layer_a = _FakeAttentionLayer(k_scale=0.3, v_scale=0.7) + worker.vllm_config.compilation_config.static_forward_context = { + "model.layers.0.self_attn": layer_a, + } + + remote_metadata = NixlAgentMetadata( + engine_id="remote-engine", + agent_metadata=b"", + kv_caches_base_addr=[], + device_id=0, + num_blocks=1, + block_lens=[], + kv_cache_layout="HND", + block_size=16, + # Empty scales → no update should happen. + fp8_k_scales=[], + fp8_v_scales=[], + ) + + worker._apply_fp8_kv_scales(remote_metadata) + + # Original scales should be unchanged. 
+ import torch + + assert torch.isclose(layer_a._k_scale, torch.tensor(0.3)) + assert torch.isclose(layer_a._v_scale, torch.tensor(0.7)) + + +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, +) +def test_register_kv_caches_includes_fp8_scales(default_vllm_config, dist_init): + """register_kv_caches embeds fp8 scales into the handshake metadata.""" + from vllm.config import set_current_vllm_config + + vllm_config = create_vllm_config(cache_dtype="fp8") + with set_current_vllm_config(vllm_config): + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + worker = connector.connector_worker + + # Inject fake fp8 attention layers BEFORE registering KV caches so + # that _collect_fp8_kv_scales picks them up. + layer_a = _FakeAttentionLayer(k_scale=0.5, v_scale=0.25) + layer_b = _FakeAttentionLayer(k_scale=1.5, v_scale=2.0) + worker.vllm_config.compilation_config.static_forward_context = { + "model.layers.0.self_attn": layer_a, + "model.layers.1.self_attn": layer_b, + } + + # Register KV caches (standard HND layout from FlashAttentionBackend). + from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend + + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) + kv_caches = { + "layer0": torch.zeros(*kv_cache_shape, dtype=torch.float8_e4m3fn), + } + connector.register_kv_caches(kv_caches) + + # Decode the metadata to verify fp8 scales were embedded. 
+ handshake_metadata = connector.get_handshake_metadata() + assert handshake_metadata is not None + agent_meta = msgspec.msgpack.Decoder(NixlAgentMetadata).decode( + handshake_metadata.agent_metadata_bytes + ) + + assert agent_meta.fp8_k_scales == [0.5, 1.5] + assert agent_meta.fp8_v_scales == [0.25, 2.0] + + +@patch( + "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper", + FakeNixlWrapper, +) +def test_rebuild_xfer_handshake_metadata_updates_fp8_scales( + default_vllm_config, dist_init +): + """rebuild_xfer_handshake_metadata refreshes fp8 scales in the payload.""" + from vllm.config import set_current_vllm_config + + vllm_config = create_vllm_config(cache_dtype="fp8") + with set_current_vllm_config(vllm_config): + connector = NixlConnector(vllm_config, KVConnectorRole.WORKER) + worker = connector.connector_worker + + # Initial scales are default 1.0. + layer_a = _FakeAttentionLayer(k_scale=1.0, v_scale=1.0) + worker.vllm_config.compilation_config.static_forward_context = { + "model.layers.0.self_attn": layer_a, + } + + from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend + + kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( + num_blocks=2, block_size=16, num_kv_heads=4, head_size=64 + ) + kv_caches = { + "layer0": torch.zeros(*kv_cache_shape, dtype=torch.float8_e4m3fn), + } + connector.register_kv_caches(kv_caches) + + # Simulate dynamic fp8 scale computation (first forward pass). + layer_a._k_scale.fill_(0.3) + layer_a._v_scale.fill_(0.7) + layer_a._k_scale_float = 0.3 + layer_a._v_scale_float = 0.7 + + # Rebuild the metadata to capture the updated scales. + worker.rebuild_xfer_handshake_metadata() + + # Verify the rebuilt metadata has updated scales. 
+ handshake_metadata = connector.get_handshake_metadata() + assert handshake_metadata is not None + agent_meta = msgspec.msgpack.Decoder(NixlAgentMetadata).decode( + handshake_metadata.agent_metadata_bytes + ) + + assert len(agent_meta.fp8_k_scales) == 1 + assert torch.isclose( + torch.tensor(agent_meta.fp8_k_scales[0]), torch.tensor(0.3) + ) + assert torch.isclose( + torch.tensor(agent_meta.fp8_v_scales[0]), torch.tensor(0.7) + ) diff --git a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py index b3f2ae703fdf..609deccb9bb0 100644 --- a/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +++ b/vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py @@ -13,7 +13,7 @@ from collections import defaultdict from collections.abc import Iterator from concurrent.futures import Future, ThreadPoolExecutor -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any import msgspec @@ -80,8 +80,9 @@ # Version History: # 1: Initial version with compatibility checking # 2: Add remote_request_id to kv_transfer_params +# 3: Add fp8_k_scales and fp8_v_scales to NixlAgentMetadata for fp8 KV cache # -NIXL_CONNECTOR_VERSION: int = 2 +NIXL_CONNECTOR_VERSION: int = 3 GET_META_MSG = b"get_meta_msg" @@ -152,6 +153,13 @@ class NixlAgentMetadata: block_lens: list[int] kv_cache_layout: str block_size: int + # FP8 KV cache quantization scales (one float per attention layer). + # Empty lists mean the sender's KV cache is not fp8 (or it found no + # attention layers to read scales from). For an fp8 KV cache the lists + # are populated from each layer's _k_scale / _v_scale, whether those + # were statically calibrated or computed dynamically.
+ fp8_k_scales: list[float] = field(default_factory=list) + fp8_v_scales: list[float] = field(default_factory=list) @dataclass @@ -1477,6 +1485,7 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): assert len(self.block_window_per_layer) == self.num_layers # After KV Caches registered, listen for new connections. + fp8_k_scales, fp8_v_scales = self._collect_fp8_kv_scales() agent_metadata = NixlAgentMetadata( engine_id=self.engine_id, agent_metadata=self.nixl_wrapper.get_agent_metadata(), @@ -1488,6 +1497,8 @@ def register_kv_caches(self, kv_caches: dict[str, torch.Tensor]): if not self.use_host_buffer else self.host_buffer_kv_cache_layout, block_size=self.block_size, + fp8_k_scales=fp8_k_scales, + fp8_v_scales=fp8_v_scales, ) # Wrap metadata in payload with hash for defensive decoding assert self.compat_hash is not None @@ -1552,6 +1563,151 @@ def register_local_xfer_handler( # NIXL_INIT_AGENT to be used for preparations of local descs. return self.nixl_wrapper.prep_xfer_dlist("NIXL_INIT_AGENT", descs), blocks_data + def _collect_fp8_kv_scales(self) -> tuple[list[float], list[float]]: + """ + Collect per-layer FP8 KV cache quantization scales from the attention + layers in the static forward context. + + Returns: + A tuple (fp8_k_scales, fp8_v_scales) where each is a list of + per-layer float scale values sorted by layer name. Both lists + are empty when the KV cache dtype is not FP8 or no layers are found. + + Note: + For dynamic FP8 (calculate_kv_scales=True), scales are computed + lazily during the first forward pass. If this method is called + before any forward pass has run, the returned values will be 1.0 + (the default initialization value). Call + rebuild_xfer_handshake_metadata() after the first forward pass to + update the handshake metadata with the computed scales.
+ """ + if not self.cache_config.cache_dtype.startswith("fp8"): + return [], [] + + static_ctx = self.vllm_config.compilation_config.static_forward_context + if not static_ctx: + logger.debug( + "static_forward_context is empty; fp8 scales will not be " + "included in handshake metadata." + ) + return [], [] + + fp8_k_scales: list[float] = [] + fp8_v_scales: list[float] = [] + for name in sorted(static_ctx.keys()): + module = static_ctx[name] + if hasattr(module, "_k_scale") and hasattr(module, "_v_scale"): + fp8_k_scales.append(module._k_scale.item()) + fp8_v_scales.append(module._v_scale.item()) + + logger.debug( + "Collected %d fp8 KV scale pairs for handshake metadata.", + len(fp8_k_scales), + ) + return fp8_k_scales, fp8_v_scales + + def rebuild_xfer_handshake_metadata(self) -> None: + """ + Rebuild the handshake metadata with the current FP8 KV scales. + + This should be called after dynamic FP8 scales have been computed + (i.e., after the first forward pass when calculate_kv_scales=True). + Subsequent D-side handshakes will then receive the correct scales. + """ + if self.xfer_handshake_metadata is None: + return + + # Re-decode the existing metadata to preserve all other fields. 
+ metadata_decoder = msgspec.msgpack.Decoder(NixlAgentMetadata) + agent_metadata = metadata_decoder.decode( + self.xfer_handshake_metadata.agent_metadata_bytes + ) + + fp8_k_scales, fp8_v_scales = self._collect_fp8_kv_scales() + agent_metadata.fp8_k_scales = fp8_k_scales + agent_metadata.fp8_v_scales = fp8_v_scales + + encoder = msgspec.msgpack.Encoder() + assert self.compat_hash is not None + self.xfer_handshake_metadata = NixlHandshakePayload( + compatibility_hash=self.compat_hash, + agent_metadata_bytes=encoder.encode(agent_metadata), + ) + logger.debug( + "Rebuilt handshake metadata with %d fp8 scale pairs.", + len(fp8_k_scales), + ) + + def _apply_fp8_kv_scales(self, nixl_agent_meta: NixlAgentMetadata) -> None: + """ + Apply the FP8 KV cache quantization scales received from the remote + (P-side) agent to the local (D-side) attention layers. + + The scales are applied only when: + - The remote metadata contains non-empty, equal-length fp8 scale lists. + - The local KV cache dtype is FP8. + - The number of remote scales matches the number of local attention + layers that expose ``_k_scale`` / ``_v_scale`` attributes. + + Args: + nixl_agent_meta: Metadata received from the remote P-side agent.
+ """ + remote_k_scales = nixl_agent_meta.fp8_k_scales + remote_v_scales = nixl_agent_meta.fp8_v_scales + + if not remote_k_scales: + return + + if len(remote_k_scales) != len(remote_v_scales): + logger.warning( + "Cannot apply remote fp8 scales: fp8_k_scales length (%d) " + "does not match fp8_v_scales length (%d); ignoring.", + len(remote_k_scales), + len(remote_v_scales), + ) + return + + if not self.cache_config.cache_dtype.startswith("fp8"): + logger.warning( + "Received fp8 KV scales from remote but local cache dtype is " + "%s; ignoring received scales.", + self.cache_config.cache_dtype, + ) + return + + static_ctx = self.vllm_config.compilation_config.static_forward_context + local_attn_names = sorted( + name + for name, module in static_ctx.items() + if hasattr(module, "_k_scale") and hasattr(module, "_v_scale") + ) + + if len(local_attn_names) != len(remote_k_scales): + logger.warning( + "Cannot apply remote fp8 scales: remote has %d scale entries " + "but local has %d attention layers with fp8 scales.", + len(remote_k_scales), + len(local_attn_names), + ) + return + + for name, k_scale, v_scale in zip( + local_attn_names, remote_k_scales, remote_v_scales + ): + module = static_ctx[name] + module._k_scale.fill_(k_scale) + module._v_scale.fill_(v_scale) + if hasattr(module, "_k_scale_float"): + module._k_scale_float = k_scale + if hasattr(module, "_v_scale_float"): + module._v_scale_float = v_scale + + logger.debug( + "Applied %d fp8 KV scale pairs from remote agent %s.", + len(remote_k_scales), + nixl_agent_meta.engine_id, + ) + def add_remote_agent( self, nixl_agent_meta: NixlAgentMetadata, @@ -1743,6 +1899,12 @@ def add_remote_agent( self.register_local_xfer_handler(nixl_agent_meta.block_size)[0] ) + # Apply fp8 KV cache quantization scales from the remote P-side agent. + # This is a no-op when remote metadata contains no fp8 scales (e.g. 
+ # bf16 KV cache). With static fp8 the scales are still sent; they are + # simply overwritten with the (presumably identical) remote values. + self._apply_fp8_kv_scales(nixl_agent_meta) + return remote_agent_name def _validate_remote_agent_handshake(