
nixl: Add fp8 KV cache scale transfer in P/D disaggregated inference#4

Draft
Copilot wants to merge 3 commits into main from copilot/support-fp8-kv-cache


Copilot AI commented Feb 26, 2026

The nixl connector transfers raw KV cache bytes without conveying fp8 quantization scales. When the Decode (D) instance uses different scales than Prefill (P) — particularly with dynamic fp8 (calculate_kv_scales=True) — D dequantizes received fp8 values with wrong scales, corrupting attention output.
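To make the failure mode concrete, here is a minimal numeric sketch. It assumes the usual convention that quantization divides by the scale and dequantization multiplies by it; fp8 rounding is left out so the arithmetic stays exact, and the helper names and scale values are hypothetical.

```python
def quantize(x: float, scale: float) -> float:
    """Map a value into the quantized domain (fp8 rounding omitted for clarity)."""
    return x / scale


def dequantize(q: float, scale: float) -> float:
    """Map a quantized value back to the original domain."""
    return q * scale


p_scale = 0.5  # hypothetical scale computed on the prefill side
d_scale = 1.0  # hypothetical default still in place on the decode side

x = 3.0
q = quantize(x, p_scale)          # what P writes into the KV cache
right = dequantize(q, p_scale)    # D using P's scale recovers x
wrong = dequantize(q, d_scale)    # D using its own scale is off by p_scale / d_scale
```

With these values the mismatched path returns a value twice as large as the original, and that error applies to every key and value in the layer.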

Changes

NixlAgentMetadata schema (NIXL_CONNECTOR_VERSION bumped to 3):

  • Added fp8_k_scales: list[float] and fp8_v_scales: list[float], the per-layer quantization scales. Both lists are empty for non-fp8 caches, or when P and D already share the same calibrated scales loaded from the model weights.
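As a rough illustration of the schema change, the sketch below uses a plain dataclass as a stand-in. Only the two scale fields and the version number come from this description; the class name, the engine_id field, and the carries_scales() helper are hypothetical, and the real NixlAgentMetadata may be defined differently.

```python
from dataclasses import dataclass, field

NIXL_CONNECTOR_VERSION = 3  # value stated in this PR description


@dataclass
class AgentMetadataSketch:
    """Stand-in for the handshake payload; not the real NixlAgentMetadata."""

    engine_id: str  # hypothetical placeholder for the existing fields
    fp8_k_scales: list[float] = field(default_factory=list)
    fp8_v_scales: list[float] = field(default_factory=list)

    def carries_scales(self) -> bool:
        # Empty lists signal "nothing to apply" on the receiving side.
        return bool(self.fp8_k_scales) or bool(self.fp8_v_scales)


bf16_meta = AgentMetadataSketch(engine_id="p0")
fp8_meta = AgentMetadataSketch(
    engine_id="p0", fp8_k_scales=[0.125, 0.25], fp8_v_scales=[0.75, 0.5]
)
```

Defaulting both lists to empty keeps the non-fp8 path unchanged: a sender that never sets them produces a payload the receiver treats as a no-op.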

P-side (prefill):

  • _collect_fp8_kv_scales() — reads _k_scale/_v_scale from all modules in compilation_config.static_forward_context that expose those attributes, sorted by module name for stable cross-instance ordering
  • register_kv_caches() — now calls _collect_fp8_kv_scales() and embeds scales in NixlAgentMetadata
  • rebuild_xfer_handshake_metadata() — rebuilds handshake payload with current scales; call this after the first forward pass when using calculate_kv_scales=True (dynamic fp8) to propagate computed scales to subsequent D-side handshakes
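The collection step described above can be sketched as follows. This is an illustration under stated assumptions, not the code in this PR: a plain dict stands in for compilation_config.static_forward_context, SimpleNamespace objects stand in for attention modules, scales are plain floats rather than tensors, and the function name is hypothetical.

```python
from types import SimpleNamespace


def collect_fp8_kv_scales(forward_context):
    """Return (k_scales, v_scales) ordered by module name."""
    k_scales, v_scales = [], []
    # Sorting by name gives both instances the same layer order,
    # regardless of dict insertion order.
    for name in sorted(forward_context):
        module = forward_context[name]
        # Skip modules that do not expose both scale attributes.
        if hasattr(module, "_k_scale") and hasattr(module, "_v_scale"):
            k_scales.append(float(module._k_scale))
            v_scales.append(float(module._v_scale))
    return k_scales, v_scales


context = {
    "model.layers.1.attn": SimpleNamespace(_k_scale=0.25, _v_scale=0.5),
    "model.layers.0.attn": SimpleNamespace(_k_scale=0.125, _v_scale=0.75),
    "model.norm": SimpleNamespace(),  # no scales, ignored
}
k_scales, v_scales = collect_fp8_kv_scales(context)
```

Note that sorting by name is lexicographic, so "layers.10" sorts before "layers.2". That is harmless as long as the receiver applies the scales using the same ordering rule.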

D-side (decode):

  • _apply_fp8_kv_scales() — applies received scales to local attention layers; validates k/v list length parity and local cache dtype before writing; no-op for empty scale lists (bf16, or static fp8 where both sides share weights)
  • add_remote_agent() — calls _apply_fp8_kv_scales() after handshake completes
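A minimal sketch of the receiving side, assuming the checks listed above. The function name, the string form of the cache dtype, the choice to raise ValueError on a failed check, and the extra layer-count check are all assumptions for illustration; the PR may log and skip instead.

```python
from types import SimpleNamespace


def apply_fp8_kv_scales(layers, k_scales, v_scales, cache_dtype):
    """Write received scales onto local layers; return True if anything was written."""
    # Empty lists mean there is nothing to apply (bf16, or static fp8 with shared weights).
    if not k_scales and not v_scales:
        return False
    # Validate everything before writing so a bad payload never half-updates the layers.
    if len(k_scales) != len(v_scales):
        raise ValueError("k/v scale lists differ in length")
    if not cache_dtype.startswith("fp8"):
        raise ValueError(f"local cache dtype {cache_dtype!r} is not fp8")
    names = sorted(layers)  # same ordering rule as the sender
    if len(names) != len(k_scales):
        raise ValueError("scale count does not match local layer count")
    for name, k, v in zip(names, k_scales, v_scales):
        layers[name]._k_scale = k
        layers[name]._v_scale = v
    return True


local_layers = {
    "layers.0.attn": SimpleNamespace(_k_scale=1.0, _v_scale=1.0),
    "layers.1.attn": SimpleNamespace(_k_scale=1.0, _v_scale=1.0),
}
noop = apply_fp8_kv_scales(local_layers, [], [], "fp8_e4m3")
applied = apply_fp8_kv_scales(local_layers, [0.125, 0.25], [0.75, 0.5], "fp8_e4m3")
```

Running all validation before the first write is a deliberate choice in this sketch: a mismatched payload leaves the local scales exactly as they were.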

Notes on dynamic fp8 timing

For calculate_kv_scales=True, scales are computed lazily during the first forward pass, which happens after register_kv_caches() runs. The handshake metadata built at registration time therefore carries the default scale of 1.0. Call rebuild_xfer_handshake_metadata() after the first forward pass so that subsequent D-side handshakes receive the computed scales.
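The ordering issue can be shown with a toy model of the P-side worker. Only the method names register_kv_caches() and rebuild_xfer_handshake_metadata() come from this description; the class, its attributes, the dict payload, and the scale values are hypothetical.

```python
class PrefillWorkerSketch:
    """Toy stand-in for the P-side worker; not the real connector."""

    def __init__(self):
        self.k_scale = 1.0  # default until the first forward pass
        self.v_scale = 1.0
        self.scales_computed = False
        self.handshake_metadata = {}

    def _build_metadata(self):
        # Snapshot the current scales into a fresh payload.
        return {"fp8_k_scales": [self.k_scale], "fp8_v_scales": [self.v_scale]}

    def register_kv_caches(self):
        self.handshake_metadata = self._build_metadata()

    def forward(self):
        # Dynamic fp8: scales are computed lazily on the first pass only.
        if not self.scales_computed:
            self.k_scale, self.v_scale = 0.25, 0.5  # made-up computed values
            self.scales_computed = True

    def rebuild_xfer_handshake_metadata(self):
        self.handshake_metadata = self._build_metadata()


worker = PrefillWorkerSketch()
worker.register_kv_caches()
at_registration = worker.handshake_metadata["fp8_k_scales"]
worker.forward()
stale = worker.handshake_metadata["fp8_k_scales"]
worker.rebuild_xfer_handshake_metadata()
fresh = worker.handshake_metadata["fp8_k_scales"]
```

In this sketch the payload is a snapshot, so the forward pass alone does not update it. Any decode instance that handshakes between the first forward pass and the rebuild would still see the 1.0 defaults.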



Copilot AI and others added 2 commits February 26, 2026 02:17
Co-authored-by: zhenwei-intel <109187816+zhenwei-intel@users.noreply.github.com>
…rch.isclose

Co-authored-by: zhenwei-intel <109187816+zhenwei-intel@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Add support for fp8 kv cache transmission" to "nixl: Add fp8 KV cache scale transfer in P/D disaggregated inference" on Feb 26, 2026
