Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 246 additions & 0 deletions tests/v1/kv_connector/unit/test_nixl_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2326,3 +2326,249 @@ def test_handshake_decode_errors(default_vllm_config, dist_init, error_scenario)
remote_tp_size=1,
expected_engine_id=FakeNixlConnectorWorker.REMOTE_ENGINE_ID,
)


# ---------------------------------------------------------------------------
# FP8 KV cache scale transfer tests
# ---------------------------------------------------------------------------


class _FakeAttentionLayer:
    """Minimal fake attention layer with FP8 scale attributes.

    Every scale is stored twice: as a 0-dim ``float32`` tensor
    (``_k_scale`` / ``_v_scale``) and as the plain value that was passed in
    (``_k_scale_float`` / ``_v_scale_float``).

    Args:
        k_scale: Scale stored for the key cache. Defaults to 1.0.
        v_scale: Scale stored for the value cache. Defaults to 1.0.
    """

    def __init__(self, k_scale: float = 1.0, v_scale: float = 1.0) -> None:
        # ``torch`` comes from the module-level import that the other tests in
        # this file already rely on; no function-local import is needed.
        self._k_scale = torch.tensor(k_scale, dtype=torch.float32)
        self._v_scale = torch.tensor(v_scale, dtype=torch.float32)
        # Plain-value mirrors of the tensors above.
        self._k_scale_float = k_scale
        self._v_scale_float = v_scale


@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_collect_fp8_kv_scales_non_fp8(default_vllm_config, dist_init):
    """With ``cache_dtype="auto"`` the worker must collect no FP8 scales.

    Builds a worker-role connector for a non-fp8 config and expects
    ``_collect_fp8_kv_scales`` to return a pair of empty lists.
    """
    config = create_vllm_config(cache_dtype="auto")
    # Keep the connector referenced for the whole test, as the worker is
    # owned by it.
    nixl_connector = NixlConnector(config, KVConnectorRole.WORKER)
    connector_worker = nixl_connector.connector_worker

    key_scales, value_scales = connector_worker._collect_fp8_kv_scales()

    assert key_scales == []
    assert value_scales == []


@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_collect_fp8_kv_scales_fp8(default_vllm_config, dist_init):
    """With ``cache_dtype="fp8"`` the worker reports one K and V scale per layer.

    The forward context holds two fake attention layers plus one plain
    ``object()``; the test expects only the two attention layers' scales, in
    layer order.
    """
    config = create_vllm_config(cache_dtype="fp8")
    nixl_connector = NixlConnector(config, KVConnectorRole.WORKER)
    connector_worker = nixl_connector.connector_worker

    # Entries are listed in sorted layer-name order for the attention layers,
    # which is the order the assertions below expect. The trailing entry is a
    # non-attention object that the test expects to be skipped.
    forward_context = {
        "model.layers.0.self_attn": _FakeAttentionLayer(k_scale=0.5, v_scale=0.25),
        "model.layers.1.self_attn": _FakeAttentionLayer(k_scale=1.5, v_scale=2.0),
        "model.layers.0.mlp": object(),
    }
    compilation_config = connector_worker.vllm_config.compilation_config
    compilation_config.static_forward_context = forward_context

    key_scales, value_scales = connector_worker._collect_fp8_kv_scales()

    assert key_scales == [0.5, 1.5]
    assert value_scales == [0.25, 2.0]


@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_apply_fp8_kv_scales(default_vllm_config, dist_init):
    """_apply_fp8_kv_scales updates local attention layers with remote scales.

    Two fake layers start with K/V scales of 1.0. After applying metadata that
    carries ``fp8_k_scales=[0.5, 1.5]`` and ``fp8_v_scales=[0.25, 2.0]``, the
    test expects both the tensor attributes and the plain-float attributes of
    each layer to hold the corresponding remote value.
    """
    vllm_config = create_vllm_config(cache_dtype="fp8")
    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    worker = connector.connector_worker

    # Set up local attention layers with default scales.
    layer_a = _FakeAttentionLayer(k_scale=1.0, v_scale=1.0)
    layer_b = _FakeAttentionLayer(k_scale=1.0, v_scale=1.0)
    worker.vllm_config.compilation_config.static_forward_context = {
        "model.layers.0.self_attn": layer_a,
        "model.layers.1.self_attn": layer_b,
    }

    # Simulate metadata received from remote P-side with calibrated scales.
    remote_metadata = NixlAgentMetadata(
        engine_id="remote-engine",
        agent_metadata=b"",
        kv_caches_base_addr=[],
        device_id=0,
        num_blocks=1,
        block_lens=[],
        kv_cache_layout="HND",
        block_size=16,
        fp8_k_scales=[0.5, 1.5],
        fp8_v_scales=[0.25, 2.0],
    )

    worker._apply_fp8_kv_scales(remote_metadata)

    # Verify scales were updated. ``torch`` is the module-level import that
    # other tests in this file already use, so no local import is needed.
    assert torch.isclose(layer_a._k_scale, torch.tensor(0.5))
    assert torch.isclose(layer_a._v_scale, torch.tensor(0.25))
    assert layer_a._k_scale_float == 0.5
    assert layer_a._v_scale_float == 0.25

    assert torch.isclose(layer_b._k_scale, torch.tensor(1.5))
    assert torch.isclose(layer_b._v_scale, torch.tensor(2.0))
    assert layer_b._k_scale_float == 1.5
    assert layer_b._v_scale_float == 2.0


@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_apply_fp8_kv_scales_no_op_for_empty(default_vllm_config, dist_init):
    """_apply_fp8_kv_scales is a no-op when remote metadata has empty scales.

    A single fake layer starts with ``k_scale=0.3`` and ``v_scale=0.7``. After
    applying metadata whose scale lists are empty, the test expects both the
    tensor attributes and the plain-float attributes to be unchanged.
    """
    vllm_config = create_vllm_config(cache_dtype="fp8")
    connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
    worker = connector.connector_worker

    layer_a = _FakeAttentionLayer(k_scale=0.3, v_scale=0.7)
    worker.vllm_config.compilation_config.static_forward_context = {
        "model.layers.0.self_attn": layer_a,
    }

    remote_metadata = NixlAgentMetadata(
        engine_id="remote-engine",
        agent_metadata=b"",
        kv_caches_base_addr=[],
        device_id=0,
        num_blocks=1,
        block_lens=[],
        kv_cache_layout="HND",
        block_size=16,
        # Empty scales → no update should happen.
        fp8_k_scales=[],
        fp8_v_scales=[],
    )

    worker._apply_fp8_kv_scales(remote_metadata)

    # Original tensor scales should be unchanged.
    assert torch.isclose(layer_a._k_scale, torch.tensor(0.3))
    assert torch.isclose(layer_a._v_scale, torch.tensor(0.7))
    # The float mirrors must be untouched as well; the fake layer stores the
    # exact Python floats it was constructed with, so equality is exact.
    assert layer_a._k_scale_float == 0.3
    assert layer_a._v_scale_float == 0.7


@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_register_kv_caches_includes_fp8_scales(default_vllm_config, dist_init):
    """register_kv_caches embeds fp8 scales into the handshake metadata.

    Registers one fp8 KV cache tensor after two fake attention layers have
    been placed in the forward context, then decodes the handshake payload
    and expects it to carry those layers' K/V scales.
    """
    from vllm.config import set_current_vllm_config

    config = create_vllm_config(cache_dtype="fp8")
    # The config stays current for the whole test body.
    with set_current_vllm_config(config):
        nixl_connector = NixlConnector(config, KVConnectorRole.WORKER)
        connector_worker = nixl_connector.connector_worker

        # The fake fp8 attention layers must be in place BEFORE the KV caches
        # are registered so that the scales are available at that point.
        first_layer = _FakeAttentionLayer(k_scale=0.5, v_scale=0.25)
        second_layer = _FakeAttentionLayer(k_scale=1.5, v_scale=2.0)
        compilation_config = connector_worker.vllm_config.compilation_config
        compilation_config.static_forward_context = {
            "model.layers.0.self_attn": first_layer,
            "model.layers.1.self_attn": second_layer,
        }

        # Register KV caches using the shape reported by FlashAttentionBackend.
        from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

        cache_shape = FlashAttentionBackend.get_kv_cache_shape(
            num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
        )
        fp8_cache = torch.zeros(*cache_shape, dtype=torch.float8_e4m3fn)
        nixl_connector.register_kv_caches({"layer0": fp8_cache})

        # Decode the handshake payload and check the embedded fp8 scales.
        handshake = nixl_connector.get_handshake_metadata()
        assert handshake is not None
        decoder = msgspec.msgpack.Decoder(NixlAgentMetadata)
        decoded_meta = decoder.decode(handshake.agent_metadata_bytes)

        assert decoded_meta.fp8_k_scales == [0.5, 1.5]
        assert decoded_meta.fp8_v_scales == [0.25, 2.0]


@patch(
    "vllm.distributed.kv_transfer.kv_connector.v1.nixl_connector.NixlWrapper",
    FakeNixlWrapper,
)
def test_rebuild_xfer_handshake_metadata_updates_fp8_scales(
    default_vllm_config, dist_init
):
    """rebuild_xfer_handshake_metadata refreshes fp8 scales in the payload.

    Registers a KV cache while the single fake layer still has scales of 1.0,
    then mutates the layer's scales to 0.3 / 0.7 and rebuilds the handshake
    metadata. The decoded payload is expected to contain exactly one K scale
    and one V scale, matching the mutated values.
    """
    from vllm.config import set_current_vllm_config

    vllm_config = create_vllm_config(cache_dtype="fp8")
    # The config stays current for the whole test body.
    with set_current_vllm_config(vllm_config):
        connector = NixlConnector(vllm_config, KVConnectorRole.WORKER)
        worker = connector.connector_worker

        # Initial scales are default 1.0.
        layer_a = _FakeAttentionLayer(k_scale=1.0, v_scale=1.0)
        worker.vllm_config.compilation_config.static_forward_context = {
            "model.layers.0.self_attn": layer_a,
        }

        from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend

        kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape(
            num_blocks=2, block_size=16, num_kv_heads=4, head_size=64
        )
        kv_caches = {
            "layer0": torch.zeros(*kv_cache_shape, dtype=torch.float8_e4m3fn),
        }
        connector.register_kv_caches(kv_caches)

        # Simulate dynamic fp8 scale computation (first forward pass): update
        # both the tensors (in place) and their plain-float mirrors.
        layer_a._k_scale.fill_(0.3)
        layer_a._v_scale.fill_(0.7)
        layer_a._k_scale_float = 0.3
        layer_a._v_scale_float = 0.7

        # Rebuild the metadata to capture the updated scales.
        worker.rebuild_xfer_handshake_metadata()

        # Verify the rebuilt metadata has updated scales.
        handshake_metadata = connector.get_handshake_metadata()
        assert handshake_metadata is not None
        agent_meta = msgspec.msgpack.Decoder(NixlAgentMetadata).decode(
            handshake_metadata.agent_metadata_bytes
        )

        # Check both list lengths before indexing so a missing V scale fails
        # with a clear assertion rather than an IndexError.
        assert len(agent_meta.fp8_k_scales) == 1
        assert len(agent_meta.fp8_v_scales) == 1
        # The values were set through float32 tensors, so compare with a
        # tolerance rather than exact equality.
        assert torch.isclose(
            torch.tensor(agent_meta.fp8_k_scales[0]), torch.tensor(0.3)
        )
        assert torch.isclose(
            torch.tensor(agent_meta.fp8_v_scales[0]), torch.tensor(0.7)
        )
Loading