nixl: Add fp8 KV cache scale transfer in P/D disaggregated inference #4
Draft
Copilot wants to merge 3 commits into
Co-authored-by: zhenwei-intel <109187816+zhenwei-intel@users.noreply.github.com>
…rch.isclose Co-authored-by: zhenwei-intel <109187816+zhenwei-intel@users.noreply.github.com>
Copilot AI changed the title from "[WIP] Add support for fp8 kv cache transmission" to "nixl: Add fp8 KV cache scale transfer in P/D disaggregated inference" on Feb 26, 2026
The nixl connector transfers raw KV cache bytes without conveying fp8 quantization scales. When the Decode (D) instance uses different scales than the Prefill (P) instance, particularly with dynamic fp8 (`calculate_kv_scales=True`), D dequantizes the received fp8 values with the wrong scales, which corrupts the attention output.

Changes

`NixlAgentMetadata` schema (`NIXL_CONNECTOR_VERSION` bumped to 3):

- `fp8_k_scales: list[float]` and `fp8_v_scales: list[float]`: per-layer quantization scales; empty for non-fp8 caches, or when P and D share calibrated scales loaded from the weights

P-side (prefill):

- `_collect_fp8_kv_scales()`: reads `_k_scale` / `_v_scale` from all modules in `compilation_config.static_forward_context` that expose those attributes, sorted by module name for stable cross-instance ordering
- `register_kv_caches()`: now calls `_collect_fp8_kv_scales()` and embeds the scales in `NixlAgentMetadata`
- `rebuild_xfer_handshake_metadata()`: rebuilds the handshake payload with the current scales; call this after the first forward pass when using `calculate_kv_scales=True` (dynamic fp8) to propagate the computed scales to subsequent D-side handshakes

D-side (decode):

- `_apply_fp8_kv_scales()`: applies the received scales to the local attention layers; validates k/v list length parity and the local cache dtype before writing; no-op for empty scale lists (bf16, or static fp8 where both sides share weights)
- `add_remote_agent()`: calls `_apply_fp8_kv_scales()` after the handshake completes

Notes on dynamic fp8 timing

For `calculate_kv_scales=True`, scales are computed lazily during the first forward pass, which happens after `register_kv_caches()` runs. The handshake metadata built at registration time therefore carries the 1.0 defaults. Call `rebuild_xfer_handshake_metadata()` after the first forward pass so that D-side handshakes receive the computed scales.
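To make the collect / handshake / apply flow and the timing caveat concrete, here is a minimal, dependency-free sketch. It is not the code in this PR: `ToyAttention`, `ToyMetadata`, `collect_scales`, and `apply_scales` are hypothetical stand-ins, scales are plain floats rather than tensors, and raising on a validation failure is an assumption (the description only says the lists and dtype are validated before writing).

```python
from dataclasses import dataclass, field


@dataclass
class ToyAttention:
    """Hypothetical stand-in for an attention layer (real scales are tensors)."""
    _k_scale: float = 1.0
    _v_scale: float = 1.0


@dataclass
class ToyMetadata:
    """Hypothetical, trimmed-down stand-in for the handshake metadata."""
    fp8_k_scales: list[float] = field(default_factory=list)
    fp8_v_scales: list[float] = field(default_factory=list)


def _has_scales(module: object) -> bool:
    return hasattr(module, "_k_scale") and hasattr(module, "_v_scale")


def collect_scales(modules: dict[str, object]) -> ToyMetadata:
    """P-side: read scales from modules that expose them, sorted by name."""
    meta = ToyMetadata()
    for name in sorted(modules):
        if _has_scales(modules[name]):
            meta.fp8_k_scales.append(float(modules[name]._k_scale))
            meta.fp8_v_scales.append(float(modules[name]._v_scale))
    return meta


def apply_scales(modules: dict[str, object], meta: ToyMetadata,
                 cache_dtype: str) -> bool:
    """D-side: write received scales into local layers; True if applied."""
    if not meta.fp8_k_scales and not meta.fp8_v_scales:
        return False  # empty lists: bf16 or shared static scales, nothing to do
    if len(meta.fp8_k_scales) != len(meta.fp8_v_scales):
        raise ValueError("k/v scale lists differ in length")
    if not cache_dtype.startswith("fp8"):
        raise ValueError("received fp8 scales but local cache is not fp8")
    targets = [n for n in sorted(modules) if _has_scales(modules[n])]
    # Extra check, not mentioned in the PR description: layer counts must match.
    if len(targets) != len(meta.fp8_k_scales):
        raise ValueError("remote and local layer counts differ")
    for name, k, v in zip(targets, meta.fp8_k_scales, meta.fp8_v_scales):
        modules[name]._k_scale = k
        modules[name]._v_scale = v
    return True


# Insertion order differs on purpose; sorting by name keeps P and D aligned.
prefill = {"layers.1.attn": ToyAttention(), "layers.0.attn": ToyAttention(),
           "norm": object()}
decode = {"layers.0.attn": ToyAttention(), "layers.1.attn": ToyAttention()}

# Metadata built at registration time only sees the 1.0 defaults.
stale = collect_scales(prefill)

# Simulate the first forward pass computing dynamic scales on P.
prefill["layers.0.attn"]._k_scale, prefill["layers.0.attn"]._v_scale = 0.5, 0.25
prefill["layers.1.attn"]._k_scale, prefill["layers.1.attn"]._v_scale = 0.125, 2.0

# Rebuilding after the forward pass picks up the computed values.
fresh = collect_scales(prefill)
applied = apply_scales(decode, fresh, cache_dtype="fp8_e4m3")
skipped = apply_scales(decode, ToyMetadata(), cache_dtype="bfloat16")
```

In this toy run, `stale` is what a D instance would receive if the handshake payload were never rebuilt, while `fresh` corresponds to the payload after the rebuild step. The `"fp8_e4m3"` dtype string is only illustrative.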