feat: NixlConnector support for Neuron disaggregated inference #26

@dmvevents

Description

Summary

Enable NixlConnector (NIXL KV cache transfer) support in vllm-neuron for disaggregated prefill/decode serving where a Trainium worker handles prefill and an NVIDIA GPU worker handles decode. KV cache is transferred from Neuron device memory through CPU DRAM via NIXL LIBFABRIC over EFA RDMA.

This follows the same host-buffer pattern already established by TPU and XPU in vLLM core.
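The host-buffer pattern can be sketched roughly as follows. This is an illustrative, self-contained sketch, not the actual vLLM/NIXL code: all class and method names here are hypothetical, and the real implementation lives in vLLM's model runner and NixlConnector. The key idea is that KV blocks are staged from Neuron device memory into a CPU buffer, and it is the CPU buffer that NIXL registers for RDMA.

```python
# Sketch of the TPU/XPU-style host-buffer pattern, adapted for Neuron.
# All names are illustrative; they do not match vllm-neuron's real API.

class HostStagedKVCache:
    """Stage device KV blocks through a CPU buffer for NIXL registration."""

    def __init__(self, num_blocks: int, block_bytes: int):
        # CPU-side staging buffer. This is what NIXL registers, because
        # tensors on the Neuron 'privateuseone' device cannot be
        # registered directly (see patch 3 / issue #22 below).
        self.cpu_buffer = bytearray(num_blocks * block_bytes)
        self.block_bytes = block_bytes

    def stage_to_host(self, block_id: int, device_block: bytes) -> None:
        # In the real runner this is a device-to-host copy; here it is
        # simply a memcpy into the staging buffer.
        start = block_id * self.block_bytes
        self.cpu_buffer[start:start + self.block_bytes] = device_block

    def registered_region(self) -> memoryview:
        # NIXL would register (address, length) of this CPU region and
        # serve RDMA reads from it over EFA via the LIBFABRIC backend.
        return memoryview(self.cpu_buffer)

cache = HostStagedKVCache(num_blocks=4, block_bytes=8)
cache.stage_to_host(1, b"\x01" * 8)
assert bytes(cache.registered_region()[8:16]) == b"\x01" * 8
```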

Validated by a proof of concept: Trainium prefill (24 ms) -> NIXL RDMA transfer -> H100 decode (2139 ms total, Qwen3-0.6B), served via Dynamo with KV-aware routing.

Required Changes (7 patches)

We have a working POC with 7 patches. Each has been filed as an individual issue:

  1. `get_kv_cache_shape()` (#20): currently raises `NotImplementedError`, which blocks NixlConnector. Implement the shape method in `neuron_attn.py` so it returns `(2, num_blocks, num_kv_heads, block_size, head_size)`.
  2. `get_kv_connector_handshake_metadata()` (#21): `NeuronWorker` is missing this method. Add the handshake method to `NeuronWorker`.
  3. KV CPU registration (#22): KV cache tensors on the `privateuseone` device cannot be registered with NIXL. Move KV tensors to CPU in the model runner before NIXL registration.
  4. `get_attn_backend_cls(**kwargs)` (#23): the signature is missing `**kwargs`, which breaks vLLM 0.16+. Add `**kwargs` for forward compatibility.
  5. `max_model_len` attribute (#24): `scheduler_config.max_model_len` raises `AttributeError` on vLLM 0.16. Read the value from `model_config` instead.
  6. NIXL context blocking (#25): the NIXL context manager blocks the forward pass when no KV transfers are pending. Guard context-manager entry in that case.
  7. `max_concurrent_batches=1`: force synchronous execution for Neuron (no async batch overlap).
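Patches 1 and 4 can be sketched as follows. This is a hedged illustration only: the actual signatures and class names in vllm-neuron may differ, and the `get_attn_backend_cls` arguments shown here are placeholders.

```python
# Illustrative sketches of patches 1 (#20) and 4 (#23); actual
# vllm-neuron signatures may differ.

class NeuronAttentionBackend:
    @staticmethod
    def get_kv_cache_shape(num_blocks: int, block_size: int,
                           num_kv_heads: int, head_size: int):
        # Patch 1: instead of raising NotImplementedError, return the
        # shape NixlConnector expects (K and V planes stacked first).
        return (2, num_blocks, num_kv_heads, block_size, head_size)

def get_attn_backend_cls(selected_backend, head_size, dtype, **kwargs):
    # Patch 4: accept **kwargs so newer vLLM (0.16+) can pass extra
    # arguments without breaking the Neuron platform hook.
    return NeuronAttentionBackend

backend = get_attn_backend_cls(None, 64, "bf16", use_mla=False)
assert backend.get_kv_cache_shape(16, 128, 8, 64) == (2, 16, 8, 128, 64)
```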

Configuration

```bash
# Prefill (Trainium)
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_buffer_device":"cpu","kv_connector_extra_config":{"backends":["LIBFABRIC"]}}'

# Decode (NVIDIA GPU)
--kv-transfer-config '{"kv_connector":"NixlConnector","kv_role":"kv_both","kv_buffer_device":"cuda","kv_connector_extra_config":{"backends":["LIBFABRIC"]}}'
```
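The two configs differ only in `kv_buffer_device`. A small sketch that generates both CLI argument values from a shared template (the helper function name is hypothetical; the JSON fields match the configs above):

```python
import json

def kv_transfer_config(buffer_device: str) -> str:
    # Shared template: NixlConnector with the LIBFABRIC (EFA) backend.
    # Only kv_buffer_device differs between the prefill and decode sides.
    cfg = {
        "kv_connector": "NixlConnector",
        "kv_role": "kv_both",
        "kv_buffer_device": buffer_device,
        "kv_connector_extra_config": {"backends": ["LIBFABRIC"]},
    }
    return json.dumps(cfg, separators=(",", ":"))

prefill_arg = kv_transfer_config("cpu")   # Trainium: KV staged via CPU DRAM
decode_arg = kv_transfer_config("cuda")   # NVIDIA GPU: KV lands in CUDA memory
print(f"--kv-transfer-config '{prefill_arg}'")
```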

Testing

  • Hardware: trn1.32xlarge (prefill) + p5.48xlarge/H100 (decode), same VPC/AZ, EFA RDMA
  • Model: Qwen/Qwen3-0.6B
  • Result: 24.2ms prefill on Trainium, 2139ms total with 30 tokens decoded on H100

Known Limitations

  • Prefix caching must be disabled on Trainium due to neuronx-cc NCC_ITEN404 bug (aws-neuron/aws-neuron-sdk#1304)
  • `max-num-seqs` must be 1 on Trainium to avoid the same compiler bug
