diff --git a/docs/MX_RL_OVERVIEW.md b/docs/MX_RL_OVERVIEW.md
new file mode 100644
index 00000000..d3f01aa3
--- /dev/null
+++ b/docs/MX_RL_OVERVIEW.md
@@ -0,0 +1,389 @@
+# ModelExpress for RL Post-Training — Design Overview
+
+**Last Updated**: April 2026
+
+This document explains how ModelExpress (MX) accelerates reinforcement learning (RL) post-training by replacing slow weight transfer mechanisms with GPU-to-GPU RDMA. It covers the general design, the PRIME-RL proof of concept (working), and the verl integration plan.
+
+---
+
+## The Problem: Weight Sync in RL
+
+RL post-training runs a continuous loop: generate text (inference) → score it → update the model (training) → sync the updated weights back to inference → repeat. The weight sync step is a bottleneck:
+
+| Method | How | Time (3 GB model) | Limitation |
+|--------|-----|-------------------|-----------|
+| Filesystem | Serialize to disk → read back | 30-60s | Disk I/O, serialization |
+| NCCL broadcast | Collective over network | 2-8s | Requires static groups, `max_async_level=1` |
+| **MX + NIXL RDMA** | **GPU reads directly from GPU** | **<1s** | **None for async training** |
+
+ModelExpress eliminates the serialization-to-disk bottleneck while preserving async training capability. NCCL forces synchronous operation; the filesystem is slow. MX + NIXL gives both speed and flexibility.
+
+---
+
+## How ModelExpress Works for RL
+
+### Architecture
+
+```mermaid
+sequenceDiagram
+ participant T as Trainer GPU
+ participant M as MX Server (gRPC + Redis)
+ participant I as Inference GPU
+
+ Note over T: 1. optimizer.step() weights updated in VRAM
+
+ T->>M: 2. publish_weights() tensor addrs + NIXL metadata via gRPC
+
+ I->>M: 3. poll_for_source() "any new weights?"
+ M-->>I: 4. get_metadata() addrs + NIXL connection info
+
+ Note over T,I: 5. NIXL RDMA READ — GPU-to-GPU data transfer (inference reads from trainer's VRAM, CPU not involved)
+ T-->>I: weight bytes (RDMA)
+
+ Note over I: 6. model.load_weights() inference applies weights
+```
+
+**MX Server** stores only metadata — tensor names, GPU memory addresses, NIXL agent connection info, version tracking. It never touches weight data. The bulk transfer is a one-sided RDMA read between GPUs.
+
+### Client Library
+
+Two classes in `modelexpress_client/python/modelexpress/`:
+
+**`MxTrainingPublisher`** (trainer side):
+```python
+publisher = MxTrainingPublisher("trainer-rank-0", device_id=0, mx_server_url="mx-server:8001")
+publisher.initialize(model_name="Qwen/Qwen2.5-1.5B")
+
+# After optimizer.step():
+publisher.publish_weights(model.state_dict(), step=training_step)
+publisher.mark_ready()
+```
+
+- Registers GPU tensors with NIXL (once, on first call — addresses are stable across steps)
+- Publishes tensor metadata + NIXL agent info to MX Server via gRPC
+- Marks version as READY so inference can discover it
+
+**`MxRefitReceiver`** (inference side):
+```python
+receiver = MxRefitReceiver("inference-rank-0", device_id=0, mx_server_url="mx-server:8001")
+receiver.initialize()
+
+source = receiver.poll_for_source(model_name="Qwen/Qwen2.5-1.5B")
+for name, tensor in receiver.receive_weights_scratch(source):
+    ...  # feed into model.load_weights()
+```
+
+- Queries MX Server for available weight sources
+- Allocates scratch GPU buffers matching the source tensor layout
+- NIXL RDMA reads weight data directly from the trainer's GPU
+- Yields `(name, tensor)` pairs for the inference engine's weight loader
+
+### The Scratch-Buffer Approach
+
+RL trainers publish weights in HuggingFace format (339 separate tensors for Qwen2.5-1.5B). Inference engines like vLLM use fused tensors internally (198 parameters — Q/K/V merged into `qkv_proj`, gate/up merged into `gate_up_proj`). Names and shapes don't match.
+
+Solution: RDMA into temporary scratch buffers matching the trainer's layout, then feed through `model.load_weights()` which handles the name mapping and tensor fusion. The RDMA layer stays simple (just move bytes); the inference engine handles semantics.
+
+---
+
+## PRIME-RL POC (Working)
+
+### What is PRIME-RL?
+
+An async-first RL framework by PrimeIntellect with three separate processes:
+- **Trainer** — FSDP2 training (GPU pod)
+- **Orchestrator** — Coordination, scoring (CPU pod)
+- **Inference** — vLLM rollout generation (GPU pod)
+
+No Ray dependency. Raw Kubernetes pods.
+
+### Architecture
+
+```mermaid
+graph LR
+ subgraph cluster["GKE GB200 Cluster"]
+ direction TB
+
+ subgraph control["Control Plane · CPU"]
+ direction LR
+ orch["Orchestrator"]
+ mx["MX Server + Redis"]
+ end
+
+ subgraph gpus["GPU Pods · 4x GB200 each"]
+ direction LR
+
+ subgraph tp["Trainer Pod"]
+ direction TB
+ fsdp["FSDP2 Training"]
+ pub["MX Publisher"]
+ nt(["NIXL Agent"])
+ fsdp --> pub --> nt
+ end
+
+ subgraph ip["Inference Pod"]
+ direction TB
+ vllm["vLLM Server"]
+ recv["MX Receiver"]
+ ni(["NIXL Agent"])
+ ni --> recv --> vllm
+ end
+ end
+
+ orch -. "HTTP rollouts" .-> vllm
+ orch -. "HTTP update_weights" .-> recv
+ pub -- "gRPC publish" --> mx
+ recv -- "gRPC discover" --> mx
+ nt <== "RDMA RoCE · 261-330 Gbps" ==> ni
+ end
+
+ style cluster fill:#1a1a2e,stroke:#16213e,color:#e0e0e0
+ style control fill:#1a1a2e,stroke:#533483,color:#e0e0e0
+ style gpus fill:#0f3460,stroke:#533483,color:#e0e0e0
+ style tp fill:#162447,stroke:#533483,color:#e0e0e0
+ style ip fill:#162447,stroke:#533483,color:#e0e0e0
+ style fsdp fill:#533483,stroke:#e94560,color:#fff
+ style vllm fill:#533483,stroke:#e94560,color:#fff
+ style orch fill:#533483,stroke:#e94560,color:#fff
+ style pub fill:#1b5e20,stroke:#4caf50,color:#fff
+ style recv fill:#1b5e20,stroke:#4caf50,color:#fff
+ style mx fill:#1b5e20,stroke:#4caf50,color:#fff
+ style nt fill:#2e7d32,stroke:#66bb6a,color:#fff
+ style ni fill:#2e7d32,stroke:#66bb6a,color:#fff
+```
+
+Green = ModelExpress / NIXL components. Purple = existing framework components.
+
+### Integration Summary
+
+| Item | Details |
+|------|---------|
+| **Repo** | `github.com/KavinKrishnan/prime-rl`, branch `kavink/mx-weight-broadcast` |
+| **MX client branch** | `github.com/ai-dynamo/modelexpress`, branch `kavink/RL` |
+| **New files in PRIME-RL** | `broadcast/modelexpress.py` (trainer), `worker/modelexpress.py` (inference), `Dockerfile.mx-arm64` |
+| **Modified files** | 8 files, 93 lines added (configs, routes, factory, client helper) |
+| **Cluster** | GKE DGXCloud, GB200 ARM64, 4 GPUs/node, RoCE networking |
+| **Model** | Qwen/Qwen2.5-1.5B (3.55 GB BF16) |
+
+### How It Works
+
+1. Trainer runs `optimizer.step()`, gathers FSDP2 shards, calls `MxTrainingPublisher.publish_weights()`
+2. Orchestrator detects new weights (via filesystem marker), tells inference to update
+3. Inference calls `MxRefitReceiver.receive_weights_scratch()` — NIXL RDMA pulls weights from trainer GPU
+4. Scratch tensors reshaped using safetensors header shapes, fed through `model.load_weights()`
+5. vLLM resumes serving with updated model
+
+### Results
+
+| Metric | Filesystem | RDMA | Speedup |
+|--------|-----------|------|---------|
+| Weight update time | ~55s | <1s | **>55x** |
+| Transfer bandwidth | ~60 MB/s | 261-330 Gbps (~33-41 GB/s) | ~500x |
+| CPU involvement | Full | None | Eliminated |
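+
+The sub-second figure can be sanity-checked with simple arithmetic. The sketch below is illustrative only: it assumes the payload is the 3.55 GB BF16 model from the Integration Summary and that the link sustains the measured 261-330 Gbps for the whole transfer. The function name `transfer_seconds` is hypothetical and not part of MX.
+
+```python
+def transfer_seconds(payload_gb: float, link_gbps: float) -> float:
+    """Ideal wire time: payload in gigabytes (1e9 bytes), link rate in gigabits/s."""
+    return payload_gb * 8 / link_gbps
+
+payload_gb = 3.55  # Qwen2.5-1.5B in BF16, from the Integration Summary table
+
+slow = transfer_seconds(payload_gb, 261)  # low end of the measured RDMA range
+fast = transfer_seconds(payload_gb, 330)  # high end of the measured RDMA range
+print(f"RDMA wire time: {fast:.3f}s to {slow:.3f}s")
+
+# Effective filesystem throughput implied by a ~55 s update of the same payload
+filesystem_mb_s = payload_gb * 1000 / 55
+print(f"Filesystem path: ~{filesystem_mb_s:.0f} MB/s")
+```
+
+Under these assumptions the wire time is roughly a tenth of a second, so the raw transfer is only a small part of the reported <1s update, and the implied filesystem rate lands close to the ~60 MB/s in the table.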
+
+### Key Issues Resolved
+
+| Issue | Root Cause | Fix |
+|-------|-----------|-----|
+| `NIXL_ERR_NOT_ALLOWED` | Wrong `UCX_TLS` value | Match TRT-LLM: `self,sm,rc,cuda_copy,gdr_copy,tcp` |
+| 800 KB metadata blob | Re-registering tensors every step | Register once, reuse cached metadata |
+| `REMOTE_DISCONNECT` | UCX 1.18 + missing IMEX channels | UCX 1.20 + DRA `compute-domain-channel` claims |
+| Tensor name mismatch | HF names vs vLLM fused names | Scratch-buffer approach + `model.load_weights()` |
+| Shape assertion | 1D scratch vs 2D expected | Read shapes from safetensors header |
+
+### Cluster Config (GCP GB200)
+
+```yaml
+hostNetwork: true
+privileged: true
+resourceClaims:
+  - name: compute-domain-channel
+    resourceClaimTemplateName: kavin-compute-domain-channel
+env:
+  UCX_TLS: "self,sm,rc,cuda_copy,gdr_copy,tcp"
+  UCX_IB_GID_INDEX: "3"
+  OMPI_MCA_pml: "ob1"
+volumes:
+  - /dev/infiniband (hostPath)
+```
+
+UCX 1.20+, IMEX channels via DRA, and `/dev/infiniband` are all required for cross-node GPU RDMA.
+
+---
+
+## verl Integration (In Progress)
+
+### What is verl?
+
+A production-grade RL framework by ByteDance that uses Ray for orchestration. Supports FSDP, Megatron, vLLM, SGLang, TRT-LLM. Has a `CheckpointEngine` plugin system for weight transfer.
+
+### Architecture
+
+```mermaid
+graph LR
+ subgraph cluster["Ray Cluster · GKE GB200"]
+ direction TB
+
+ subgraph driver["Driver · CPU"]
+ task["TaskRunner"]
+ mgr["CheckpointEngine Manager"]
+ end
+
+ subgraph trainer_wg["Trainer WorkerGroup · GPU"]
+ direction LR
+ tw0["Worker 0 FSDP2 + MX CE"]
+ tw1["Worker 1 FSDP2 + CE"]
+ tw2["Worker 2 FSDP2 + CE"]
+ tw3["Worker 3 FSDP2 + CE"]
+ end
+
+ subgraph rollout_wg["Rollout Replicas · GPU"]
+ direction LR
+ rw0["CE Worker 0"]
+ rw1["CE Worker 1"]
+ rw2["CE Worker 2"]
+ rw3["CE Worker 3"]
+ vs["vLLM Server"]
+ end
+
+ subgraph mx_svc["MX Server + Redis · CPU"]
+ mx["gRPC Metadata Broker"]
+ end
+
+ task --> mgr
+ mgr -. "ray.get" .-> tw0
+ mgr -. "ray.get" .-> rw0
+ tw0 -- "gRPC publish" --> mx
+ rw0 -- "gRPC discover" --> mx
+ tw0 <== "NIXL RDMA" ==> rw0
+ tw1 <== "NIXL RDMA" ==> rw1
+ tw2 <== "NIXL RDMA" ==> rw2
+ tw3 <== "NIXL RDMA" ==> rw3
+ rw0 -. "CUDA IPC" .-> vs
+ end
+
+ style cluster fill:#1a1a2e,stroke:#16213e,color:#e0e0e0
+ style driver fill:#1a1a2e,stroke:#533483,color:#e0e0e0
+ style trainer_wg fill:#0f3460,stroke:#533483,color:#e0e0e0
+ style rollout_wg fill:#0f3460,stroke:#533483,color:#e0e0e0
+ style mx_svc fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0
+ style task fill:#533483,stroke:#e94560,color:#fff
+ style mgr fill:#1b5e20,stroke:#4caf50,color:#fff
+ style tw0 fill:#1b5e20,stroke:#4caf50,color:#fff
+ style tw1 fill:#162447,stroke:#533483,color:#e0e0e0
+ style tw2 fill:#162447,stroke:#533483,color:#e0e0e0
+ style tw3 fill:#162447,stroke:#533483,color:#e0e0e0
+ style rw0 fill:#1b5e20,stroke:#4caf50,color:#fff
+ style rw1 fill:#2e7d32,stroke:#66bb6a,color:#fff
+ style rw2 fill:#2e7d32,stroke:#66bb6a,color:#fff
+ style rw3 fill:#2e7d32,stroke:#66bb6a,color:#fff
+ style vs fill:#533483,stroke:#e94560,color:#fff
+ style mx fill:#1b5e20,stroke:#4caf50,color:#fff
+```
+
+Green = MX checkpoint engine components. Purple = existing verl/Ray components. The `CheckpointEngineManager` coordinates the MX CE workers on both trainer and rollout sides. Each trainer-rollout rank pair transfers via NIXL RDMA. Received weights reach vLLM via CUDA IPC through the existing `ServerAdapter`.
+
+### Why It's Easier Than PRIME-RL
+
+verl already has:
+- **`CheckpointEngine` ABC** with `send_weights` / `receive_weights` — just implement a new backend
+- **Existing NIXL engine** (`nixl_checkpoint_engine.py`) as a reference implementation
+- **Bucketed transfers** that preserve tensor names and shapes — no scratch-buffer approach needed
+- **`@CheckpointEngineRegistry.register("mx")`** — one decorator to plug in
+
+### Integration Summary
+
+| Item | Details |
+|------|---------|
+| **Repo** | `github.com/KavinKrishnan/verl`, branch `kavink/mx-checkpoint-engine` |
+| **New file** | `verl/checkpoint_engine/mx_checkpoint_engine.py` (461 lines) |
+| **Modified files** | 2 files, 8 lines (imports + config comment) |
+| **Total** | 469 added lines across 1 new file and 2 modified files |
+
+### `MxCheckpointEngine` Design
+
+Registered as `@CheckpointEngineRegistry.register("mx")`. Implements the `CheckpointEngine` ABC:
+
+- **`prepare()`** — Allocate send/recv GPU buckets, register with NIXL, create MX client
+- **`build_topology()`** — Trainer rank 0 is the source, all rollout ranks connect via MX Server
+- **`init_process_group()`** — Trainer adds rollout agents; rollouts add trainer agent
+- **`send_weights()`** — Pack tensors into GPU buckets, send bucket metadata via ZMQ, make available for RDMA read
+- **`receive_weights()`** — Receive metadata via ZMQ, NIXL RDMA read from trainer's bucket, yield `(name, tensor)` pairs
+- **`finalize()`** — Cleanup connections, deregister memory
+
+Uses a **star topology** (trainer → all rollouts) instead of the NIXL engine's ring. The MX Server enables future pipeline replication where rollouts become sources.
+
+### Config
+
+```yaml
+actor_rollout_ref:
+  rollout:
+    checkpoint_engine:
+      backend: "mx"
+      engine_kwargs:
+        mx_server_url: "modelexpress-server:8001"
+        model_name: "Qwen/Qwen2.5-1.5B"
+```
+
+---
+
+## Comparison: PRIME-RL vs verl Integration
+
+| Aspect | PRIME-RL | verl |
+|--------|---------|------|
+| Plugin system | Custom broadcast/worker extensions | `CheckpointEngine` registry |
+| Existing NIXL support | None (built from scratch) | Full NIXL engine as reference |
+| Lines of code | ~350 new + 93 modified | ~460 new + 8 modified |
+| Ray dependency | None (raw K8s pods) | Ray actors, placement groups |
+| Weight format | HF names → scratch buffers → `load_weights()` | Bucketed with shapes preserved |
+| Tensor shape issue | Required safetensors header reading | Not an issue (bucket metadata carries shapes) |
+| Colocated mode | N/A (always disaggregated) | `naive` engine handles colocated; MX for disaggregated |
+
+---
+
+## Future: Pipeline Replication
+
+Current design uses a star topology (trainer → all rollouts). At scale, the trainer's NIC becomes a bottleneck. [TensorHub](https://arxiv.org/abs/2604.09107v1) (ByteDance, April 2026) demonstrates **pipeline replication**: after a rollout receives weights, it publishes itself as a source. New rollouts pull from the nearest/least-loaded replica, creating a bandwidth-amplifying DAG.
+
+MX Server already supports multiple sources per model — implementing pipeline replication is a client-side change:
+1. After RDMA receive, rollout calls `publish_weights()` on MX Server
+2. New rollouts call `poll_for_source()` which returns the nearest available replica
+3. MX Server load-balances across all replicas
+
+This is prioritized as P1 in our roadmap.
+
+---
+
+## Repository Map
+
+### ModelExpress client (`kavink/RL` branch)
+
+```text
+modelexpress_client/python/modelexpress/
+├── training_publisher.py # MxTrainingPublisher — trainer-side publish
+├── refit_receiver.py # MxRefitReceiver — inference-side RDMA receive
+├── nixl_transfer.py # NixlTransferManager — NIXL agent lifecycle
+├── client.py # MxClient — gRPC client to MX Server
+└── __init__.py # Exports MxTrainingPublisher, MxRefitReceiver
+```
+
+### PRIME-RL integration (`kavink/mx-weight-broadcast` branch)
+
+```text
+src/prime_rl/
+├── trainer/rl/broadcast/modelexpress.py # ModelExpressWeightBroadcast
+├── inference/vllm/worker/modelexpress.py # MxWeightUpdateWorker
+├── inference/vllm/server.py # /init_mx_broadcaster route (+13 lines)
+├── orchestrator/orchestrator.py # elif "modelexpress" branch (+7 lines)
+├── utils/client.py # init_mx_broadcast() (+34 lines)
+└── configs/ # MxWeightBroadcastConfig (+32 lines)
+```
+
+### verl integration (`kavink/mx-checkpoint-engine` branch)
+
+```text
+verl/
+├── checkpoint_engine/mx_checkpoint_engine.py # MxCheckpointEngine
+├── checkpoint_engine/__init__.py # Optional import (+7 lines)
+└── workers/config/rollout.py # "mx" in backend comment (+1 line)
+```
diff --git a/docs/RL/NIXL_COMPRESSION_STUDY.md b/docs/RL/NIXL_COMPRESSION_STUDY.md
new file mode 100644
index 00000000..e5df9cd4
--- /dev/null
+++ b/docs/RL/NIXL_COMPRESSION_STUDY.md
@@ -0,0 +1,303 @@
+# NIXL nvCOMP Compression Study — Reproducing with ModelExpress RL Workflows
+
+**Last Updated**: April 29, 2026
+**Audience**: NIXL compression team
+**Purpose**: Guide the NIXL team to capture and study real RL weight-transfer payloads using our validated PRIME-RL and verl workflows with ModelExpress (MX).
+
+---
+
+## Background
+
+The NIXL team is evaluating nvCOMP GPU compression on the tensors that flow through NIXL during RL post-training. There are two transfer types:
+
+1. **RL refit** (training → inference): full model weights, every RL step.
+2. **KV cache** (prefill → decode): per-request KV tensors in disaggregated inference.
+
+We have **two validated end-to-end RL workflows** that produce these payloads over NIXL on GB200:
+
+| Workflow | Framework | Status | PR | What it exercises |
+|----------|-----------|--------|-----|-------------------|
+| **PRIME-RL overlay** | PRIME-RL + vLLM | Scenarios A/B/C green on GB200 (20/20 steps each) | [PrimeIntellect-ai/prime-rl#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343) | NIXL RDMA weight push via PI's `NIXLWeightBroadcast` + `TransportPlan`, MX-mediated discovery |
+| **verl MxCheckpointEngine** | verl + vLLM | 10 steps green on GB200 | [ai-dynamo/modelexpress#252](https://github.com/ai-dynamo/modelexpress/pull/252) | NIXL RDMA weight transfer via `MxCheckpointEngine` (`CheckpointEngine` plugin) |
+
+Both produce the **exact same kind of data** the NIXL team requested: raw BF16 weight tensors flowing GPU-to-GPU over NIXL, plus pre/post RL-step weight deltas for delta-compression analysis.
+
+---
+
+## Component View
+
+This view shows where the data under study lives and what writes and reads it. Green marks what the compression team would compress; purple marks the existing RL stack that produces it.
+
+```mermaid
+flowchart TB
+ subgraph trainer_node["Trainer node — GB200"]
+ direction TB
+ T_FSDP["FSDP2 trainer optimizer.step()"]
+ T_PUB["MxTrainingPublisher + NIXLWeightBroadcast"]
+ T_NIXL(["NIXL agent (UCX rc_mlx5)"])
+ T_FSDP --> T_PUB --> T_NIXL
+ end
+
+ subgraph mx_meta["Metadata plane — MX Server"]
+ MX["MX Server (gRPC)"]
+ REDIS[("Redis")]
+ MX --> REDIS
+ end
+
+ subgraph inference_node["Inference node — GB200"]
+ direction TB
+ I_NIXL(["NIXL agent (UCX rc_mlx5)"])
+ I_RECV["MxRefitReceiver + NIXLWeightUpdateWorker"]
+ VLLM["vLLM engine (live params)"]
+ I_NIXL --> I_RECV --> VLLM
+ end
+
+ subgraph prefill_node["Prefill worker — GB200 (disagg inference)"]
+ direction TB
+ P_FWD["vLLM prefill pass"]
+ P_KV["KV cache buffer (per-layer key/value)"]
+ P_FWD --> P_KV
+ end
+
+ subgraph decode_node["Decode worker — GB200"]
+ direction TB
+ D_KV["KV cache import"]
+ D_GEN["Token generation"]
+ D_KV --> D_GEN
+ end
+
+ T_PUB -. "publish metadata (SourceIdentity, agent blob)" .-> MX
+ MX -. "discover" .-> I_RECV
+
+ T_NIXL ==> |"① RL REFIT weights (BF16) ~3 GB / step (1.5B model) ~140 GB / step (70B)"| I_NIXL
+ P_KV ==> |"② KV CACHE tensors (BF16) ~14 MB at seq=512 ~3.5 GB at seq=131K"| D_KV
+
+ style T_FSDP fill:#533483,stroke:#e94560,color:#fff
+ style T_PUB fill:#533483,stroke:#e94560,color:#fff
+ style T_NIXL fill:#1b5e20,stroke:#4caf50,color:#fff
+ style I_NIXL fill:#1b5e20,stroke:#4caf50,color:#fff
+ style I_RECV fill:#533483,stroke:#e94560,color:#fff
+ style VLLM fill:#533483,stroke:#e94560,color:#fff
+ style P_FWD fill:#533483,stroke:#e94560,color:#fff
+ style P_KV fill:#1b5e20,stroke:#4caf50,color:#fff
+ style D_KV fill:#1b5e20,stroke:#4caf50,color:#fff
+ style D_GEN fill:#533483,stroke:#e94560,color:#fff
+ style MX fill:#533483,stroke:#e94560,color:#fff
+ style REDIS fill:#162447,stroke:#533483,color:#e0e0e0
+```
+
+**Compression target = the green edges** (the two thick links between the green nodes). ① is the RL-refit path between the trainer and inference NIXL agents; ② is the KV cache transfer between prefill and decode. Everything purple (the trainer, the MX Server, vLLM, the receiver) is RL-stack infrastructure that would not change if nvCOMP were added at the NIXL layer: compression would slot in transparently between `register` and `RDMA WRITE` on either edge.
+
+The capture scripts in [`scripts/`](./scripts/) snapshot the bytes that cross those green edges, plus pre/post weight tensors for delta-compression analysis.
+
+---
+
+## Option 1: Request the pre-captured data package (fastest)
+
+We have a ready-made data package captured from a live PRIME-RL deployment on GB200. **It's not in this repo** (binary tensors at GB scale aren't appropriate to commit) — request access from `kavink@nvidia.com` and we'll share via the appropriate channel (NV S3 bucket, internal share, or direct upload).
+
+Package contents:
+
+```text
+RL_Qwen25/
+├── model.safetensors # 2.9 GB — all 338 weight tensors (BF16)
+├── weights_pre_rl.safetensors # 3.4 GB — weights before optimizer.step()
+├── weights_post_rl.safetensors # 3.4 GB — weights after 1 AdamW step (lr=5e-6)
+├── weight_deltas.safetensors # 3.4 GB — elementwise diff (post - pre), BF16
+├── kv_cache/ # 14 MB — 56 KV tensors from a 501-token prefill
+│ ├── layer_0_key.bin # shape [1, 2, 501, 128], BF16
+│ ├── layer_0_value.bin
+│ ├── ...
+│ └── manifest.json # per-tensor metadata
+├── manifest.json # 66 KB — per-weight-tensor metadata
+└── README.md # full layout + compression properties
+```
+
+**Model**: Qwen2.5-1.5B BF16, 28 layers, 1.54B parameters. ~14 GB total package size.
+
+**How to read**:
+
+```python
+from safetensors import safe_open
+import torch
+
+# Weights (the exact tensors NIXL transfers during RL refit)
+with safe_open("model.safetensors", framework="pt") as f:
+    for key in f.keys():
+        tensor = f.get_tensor(key)  # torch.bfloat16
+        raw_bytes = tensor.contiguous().untyped_storage()  # raw bytes as on the wire
+        print(f"{key}: {tuple(tensor.shape)}, {raw_bytes.nbytes()} bytes")
+
+# Weight delta (for delta-compression analysis; compute in FP32 for precision)
+with safe_open("weights_pre_rl.safetensors", framework="pt") as pre, \
+     safe_open("weights_post_rl.safetensors", framework="pt") as post:
+    for key in pre.keys():
+        delta = post.get_tensor(key).float() - pre.get_tensor(key).float()
+        print(f"{key}: max_abs_delta={delta.abs().max().item():.2e}")
+
+# KV cache (the exact tensors transferred prefill → decode via NIXL)
+raw = open("kv_cache/layer_0_key.bin", "rb").read()
+kv = torch.frombuffer(bytearray(raw), dtype=torch.bfloat16).reshape(1, 2, 501, 128)
+```
+
+**Key finding on deltas**: At BF16 precision, single-step RL deltas are mostly zero. BF16 keeps only 8 significant bits, so an AdamW update at lr=5e-6 is usually smaller than half the spacing between adjacent BF16 values at typical weight magnitudes and rounds away. For meaningful delta analysis, compute diffs in FP32. This suggests delta-compression should operate in FP32 and quantize back afterward.
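+
+The rounding effect can be reproduced without a GPU. The following is a minimal sketch using only the standard library: `to_bf16` is a hypothetical helper that emulates BF16 by rounding a float32 bit pattern to its top 16 bits (round-to-nearest-even), and the weight value 0.02 and update size 5e-6 are assumed, illustrative numbers rather than values taken from the captured package.
+
+```python
+import struct
+
+def to_bf16(x: float) -> float:
+    """Round x to the nearest BF16 value (ties to even), returned as a Python float."""
+    bits = struct.unpack("<I", struct.pack("<f", x))[0]  # float32 bit pattern
+    bits += 0x7FFF + ((bits >> 16) & 1)                  # round to nearest, ties to even
+    bits &= 0xFFFF0000                                   # keep sign, exponent, 7 mantissa bits
+    return struct.unpack("<f", struct.pack("<I", bits))[0]
+
+w = 0.02        # assumed typical weight magnitude
+update = 5e-6   # assumed per-element step, on the order of lr=5e-6
+
+pre_bf16 = to_bf16(w)
+post_bf16 = to_bf16(w + update)
+
+print("BF16 spacing near w:", 2.0 ** -13)           # spacing for values in [2^-6, 2^-5)
+print("BF16 delta:", post_bf16 - pre_bf16)          # update is lost to rounding
+print("higher-precision delta:", (w + update) - w)  # update survives
+```
+
+The update is far below half of the BF16 spacing at this magnitude, so both values round to the same BF16 number. A nonzero delta is only observable if higher-precision copies of the weights are available, which is an assumption about the training setup.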
+
+---
+
+## Option 2: Reproduce end-to-end on GB200 (PRIME-RL overlay)
+
+Run our validated PRIME-RL overlay workflow and capture weights mid-flight using the published [`scripts/`](./scripts/) directory.
+
+### Prerequisites
+
+- A GB200 cluster (ARM64) with at least 2 nodes, a container runtime, and an RDMA-capable interconnect (InfiniBand or RoCE) between nodes. Cluster orchestration is Kubernetes-based; manifests assume `kubectl` access and a working namespace.
+- A namespace where you'll deploy the overlay, with the following bound:
+ - MX Server reachable at `modelexpress-server..svc.cluster.local:8001` (Helm chart in this repo, or use an existing deployment)
+ - Redis backing the MX Server
+ - A shared model-cache PVC for HuggingFace downloads
+ - An image pull secret for the registry hosting the overlay image (we publish to `nvcr.io/nvidian/dynamo-dev/`; you can also build locally)
+
+The included K8s manifests under `prime-rl/k8s/prime-rl-mx-on-nixl/` may need light edits (namespace, node selectors, image pull secret name, RDMA network annotations) for your cluster — they're examples, not portable across all GB200 deployments. The capture flow itself is cluster-agnostic.
+
+### Step 1: Deploy the PRIME-RL overlay
+
+```bash
+git clone git@github.com:KavinKrishnan/prime-rl.git
+cd prime-rl
+git checkout kavink/mx-on-nixl
+
+# Use the pre-built ARM64 image, or build locally
+# Pre-built: nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.2
+docker buildx build --platform linux/arm64 \
+ -f docker/Dockerfile.mx-on-nixl \
+ -t /prime-rl-mx-on-nixl:v0.2 \
+ --push .
+
+# Edit k8s/prime-rl-mx-on-nixl/*.yaml for your namespace, node selectors,
+# RDMA network annotations, and image registry. Then:
+cd k8s/prime-rl-mx-on-nixl
+./run.sh deploy A # scenario A = PI's NIXL transport, no MX env vars
+./run.sh status # wait until all 3 pods are Running
+```
+
+### Step 2: Verify the RL loop is running
+
+```bash
+NS=
+kubectl -n $NS logs prime-rl-mx-on-nixl-trainer-0 --tail=20 | grep "SUCCESS.*Step"
+kubectl -n $NS logs prime-rl-mx-on-nixl-inference-0 | grep "update_weights.*200"
+```
+
+### Step 3: Capture using the published script
+
+We ship `capture_on_pod.py` in [`scripts/`](./scripts/), the same script that produced our pre-captured Qwen2.5-1.5B package. It captures pre/post RL weights, simulates one AdamW step, computes deltas, and dumps a KV cache prefill, all in one pass.
+
+```bash
+NS=
+
+# Copy the script into the trainer pod
+kubectl cp docs/RL/scripts/capture_on_pod.py \
+ $NS/prime-rl-mx-on-nixl-trainer-0:/tmp/capture.py
+
+# Run it inside the pod (overlay image's interpreter is /app/.venv/bin/python)
+kubectl -n $NS exec prime-rl-mx-on-nixl-trainer-0 -- /app/.venv/bin/python /tmp/capture.py \
+ --model Qwen/Qwen2.5-1.5B \
+ --out /tmp/nixl_capture \
+ --kv-seq-len 512 \
+ --lr 5e-6
+
+# Copy the results back
+kubectl cp $NS/prime-rl-mx-on-nixl-trainer-0:/tmp/nixl_capture ./RL_capture
+```
+
+Output `RL_capture/` contains four sub-directories (`weights_pre_rl/`, `weights_post_rl/`, `weight_deltas/`, `kv_cache/`) each with raw `.bin` files plus a `manifest.json`. See [`scripts/README.md`](./scripts/README.md) for the full layout + flag reference.
+
+### Step 4 (optional): Capture without a running RL deployment
+
+If reproducing the overlay is more cluster work than the data is worth, [`scripts/capture_weights_and_kv.py`](./scripts/capture_weights_and_kv.py) is the **standalone** variant — works on any host (CPU or single GPU), no Kubernetes / RL framework required:
+
+```bash
+pip install torch transformers safetensors
+
+python docs/RL/scripts/capture_weights_and_kv.py \
+ --model Qwen/Qwen2.5-1.5B \
+ --output-dir ./nixl_data \
+ --dtype bfloat16 \
+ --device cpu
+```
+
+This variant doesn't simulate an RL step (no pre/post/delta), but it produces the same weight and KV cache layout for the NIXL team to compress against.
+
+### Step 5: Tear down
+
+```bash
+./run.sh clean
+```
+
+---
+
+## Option 3: Reproduce with verl MxCheckpointEngine
+
+The verl integration uses the same MX client but through verl's `CheckpointEngine` plugin. This path captures the weights as they flow through `MxCheckpointEngine.send_weights()` / `receive_weights()`.
+
+Deployment docs: `docs/RL/VERL_MX_OVERVIEW.md` §6 in the modelexpress repo.
+
+The capture approach is the same as Option 2 (exec into the trainer pod, save state dict pre/post step) since the weight tensors are identical — both frameworks produce `model.named_parameters()` in BF16. The difference is the transport path (verl's bucket+ZMQ metadata vs prime-rl's TransportPlan+slot system), which doesn't affect the tensor content.
+
+---
+
+## What to capture for the compression study
+
+| Artifact | File | Size (Qwen3-0.6B) | What it represents |
+|----------|------|-------|---------------------|
+| **Current weights** | `weights_current.safetensors` | ~1.2 GB | Exact tensors registered with NIXL and RDMA-written to inference GPU every RL step |
+| **Post-step weights** | `weights_post_step.safetensors` | ~1.2 GB | After one AdamW step (lr=5e-6) |
+| **Weight deltas** | `weight_deltas.safetensors` | ~1.2 GB | `post - pre` in BF16 (mostly zero — compute in FP32 for real deltas) |
+| **KV cache** | `kv_cache/*.bin` | ~14 MB | Prefill output transferred to decode workers via NIXL |
+| **Manifest** | `manifest.json` | ~30 KB | Per-tensor: name, shape, dtype, size_bytes |
+
+### Larger models for more representative data
+
+The sizes above are for Qwen3-0.6B (our scenario A model), whereas the capture commands in Option 2 pass `--model Qwen/Qwen2.5-1.5B`. For larger models closer to production:
+
+| Model | Params | Weight payload | Notes |
+|-------|--------|----------------|-------|
+| Qwen3-0.6B (above) | 0.6B | ~1.2 GB | Validated in PR #2343 scenarios A/B/C |
+| Qwen2.5-1.5B | 1.5B | ~3 GB | Pre-captured package available on request (see Option 1) |
+| Qwen2.5-7B | 7.6B | ~15 GB | T1 model in our overlay plan |
+| Qwen3-MoE (PI offered spec) | MoE | varies | Would exercise `ExpertSlot` + per-expert tensors — most representative for MoE compression |
+
+For models requiring multiple GPUs, the weights are FSDP-sharded — each rank's shard is `total / num_ranks` in size. The bytes on the wire per-rank are the shard size, not the full model.
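+
+As a rough illustration of the per-rank payload, the sketch below applies `total / num_ranks` to two models from the table. It assumes BF16 (2 bytes per element), an even split across ranks, and no padding; `shard_bytes` is a hypothetical helper, and the rank counts are example values.
+
+```python
+def shard_bytes(num_params: float, bytes_per_elem: int, num_ranks: int) -> float:
+    """Approximate per-rank payload for evenly FSDP-sharded parameters (padding ignored)."""
+    return num_params * bytes_per_elem / num_ranks
+
+BF16_BYTES = 2
+
+for name, params in [("Qwen2.5-1.5B", 1.54e9), ("Qwen2.5-7B", 7.6e9)]:
+    for ranks in (1, 4):
+        gb = shard_bytes(params, BF16_BYTES, ranks) / 1e9
+        print(f"{name}, {ranks} rank(s): ~{gb:.2f} GB per rank")
+```
+
+With 4 ranks, each rank of the 7.6B model moves about a quarter of the ~15 GB total, which is the figure to use when sizing per-transfer compression buffers.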
+
+---
+
+## Compression-relevant properties
+
+| Property | Weights | KV Cache | Delta (FP32) |
+|----------|---------|----------|--------------|
+| **Dtype on wire** | BF16 (2 B/elem) | BF16 (2 B/elem) | BF16 stored, but FP32 is the meaningful analysis dtype |
+| **Value distribution** | Normal, centered ~0, std 0.01–0.1 | Wider, context-dependent | Very small magnitude (~1e-8 to 1e-6 per element) |
+| **Sparsity** | Dense (no zeros) | Dense | ~100% zero at BF16 precision; structured-sparse at FP32 |
+| **Best compression angle** | Entropy coding on mantissa bits | Temporal locality across layers | FP32 delta + entropy coding — high compressibility expected |
+| **Transfer frequency** | Every RL step (~5–60 s) | Every request | Once for analysis |
+| **Bucket size on wire** | 596 MB (measured in scenario A/B/C) | per-request, scales with seq_len | N/A |
+
+### NIXL integration point for nvCOMP
+
+If nvCOMP compression is added at the NIXL layer, the integration is transparent to both MX and the RL frameworks:
+
+```text
+Current:
+ Training GPU → NIXL register → RDMA WRITE (raw bytes) → Inference GPU
+
+With NIXL-layer nvCOMP:
+ Training GPU → NIXL register → nvCOMP compress (GPU) → RDMA WRITE (compressed) → nvCOMP decompress (GPU) → Inference GPU
+```
+
+No changes to `MxTrainingPublisher`, `MxRefitReceiver`, `NIXLWeightBroadcast`, `TransportPlan`, or the MX Server protocol. Compression is internal to NIXL's transfer path. Our bucket-streaming pattern is preserved — compression happens per-bucket.
+
+---
+
+## Questions?
+
+Reach out to Kavin Krishnan (`kavink@nvidia.com`) for access to the pre-captured data or help reproducing on a cluster. The PRIME-RL overlay branch (`KavinKrishnan/prime-rl:kavink/mx-on-nixl`) and the modelexpress RL branch (`ai-dynamo/modelexpress:kavink/RL`) are the entry points.
diff --git a/docs/RL/PRIMERL_MX_NATIVE_DESIGN.md b/docs/RL/PRIMERL_MX_NATIVE_DESIGN.md
new file mode 100644
index 00000000..540a272e
--- /dev/null
+++ b/docs/RL/PRIMERL_MX_NATIVE_DESIGN.md
@@ -0,0 +1,276 @@
+# PRIME-RL × ModelExpress — Native API Design (Path B)
+
+**Status**: Design proposal (no code yet)
+**Last Updated**: April 2026
+**Companion to**: `PRIMERL_MX_OVERVIEW.md` (the overlay-on-PI design, "Path A")
+
+This document describes a **native MX-shaped weight broadcast backend** for PRIME-RL that uses [PI's NIXL transport](https://github.com/PrimeIntellect-ai/prime-rl/pull/2326) as the bytes-on-wire data plane, but exposes **ModelExpress's traditional API surface** to PRIME-RL — model-agnostic, server-mediated, scratch-buffer-default, cross-framework.
+
+It's the alternative to "Path A" (strict overlay on PI's API). Path A is in flight as draft PR [PrimeIntellect-ai/prime-rl#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343); Path B is staged here as the design we'd advocate for if Path A hits friction or if the team wants the broader strategic value.
+
+---
+
+## 1. Why This Doc Exists
+
+Path A's overlay strategy preserves PI's NIXL API surface end-to-end and replaces only the SPG rendezvous with MX Server discovery. It's small, cooperative, easy to merge — and it inherits every PI design constraint:
+
+- **Per-model `conversion_specs()` requirement.** PI's `TransportPlan` only supports models that PI has authored a spec table for: today, just `glm_moe_dsa`. Plain Qwen3, Llama, Mixtral, anything else needs a spec table written. We just discovered this when scenario A's trainer crashed with `'FSDPQwen3ForCausalLM' object has no attribute 'conversion_specs'` and had to monkey-patch HF Qwen3 to unblock.
+- **Direct refit only — KL drift class of bugs is in scope.** PI's PR is currently blocked by 27+ iterations of KL drift investigation. Their byte-exact transport is correct; the drift comes from concurrent UCX writes into live vLLM `param.data`. Inheriting their target-buffer model means inheriting that bug surface.
+- **Static, startup-time tensor registration.** `Slot{Sharded,Gathered,Expert}` assume fixed tensor shapes registered with NIXL once at init. Elastic workloads (LoRA-RL with dynamic adapter add/remove, frozen-policy-adapts variants, growing context-len rollouts) don't fit this model cleanly.
+- **prime-rl-only.** PI's `NIXLWeightBroadcast` lives in `prime_rl/trainer/rl/broadcast/nixl.py`; cross-framework portability would require us to copy the design into verl + future NeMo-RL.
+
+Path B answers the question: what if we redesigned the prime-rl-side weight broadcast around MX's shape but kept PI's UCX/NIXL setup as-is for the bytes? The data plane is identical (same RDMA bytes on the wire, same per-NIC bandwidth, same `rc_mlx5` transport, same per-rank NIC pin). Only the control plane and tensor ABI change.
+
+---
+
+## 2. What MX-Traditional Looks Like (Recap)
+
+ModelExpress has a consistent API across the integrations we've shipped (verl `MxCheckpointEngine`, our existing PRIME-RL `ModelExpressWeightBroadcast` on `kavink/mx-weight-broadcast`):
+
+```python
+# Trainer side
+publisher = MxTrainingPublisher(
+ agent_name="trainer-rank-0",
+ device_id=local_rank,
+ server_url="modelexpress-server.kavin.svc.cluster.local:8001",
+)
+publisher.initialize() # NIXL agent up, gRPC connected
+publisher.publish_weights(state_dict, step=N) # per training step
+# ...later, before optimizer.step() reuses buffers:
+publisher.unpublish(version=N) # mutability contract
+
+
+# Inference side
+receiver = MxRefitReceiver(
+ model_name="...",
+ worker_rank=k,
+ device_id=local_rank,
+)
+receiver.initialize(model_tensors=scratch_or_live_dict) # NIXL register receive buffers
+source = receiver.poll_for_source(min_version=N) # discover trainer
+receiver.receive_weights(source) # RDMA pull (or accept WRITE)
+# scratch path: vllm_model.load_weights(scratch_iter)
+# direct path: receive lands directly in vllm_model.param.data
+```
+
+Salient differences from PI's design:
+
+| Concern | PI (Path A overlay inherits) | MX-traditional (Path B exposes) |
+|---------|------------------------------|--------------------------------|
+| Discovery | SPG static rendezvous | gRPC `MxClient` (publish, list_sources, get_metadata) |
+| Source identity | Implicit rank pairing | Content-addressed `mx_source_id = sha256(SourceIdentity)` |
+| Per-model contract | `model.conversion_specs(layer_idx)` | None — model-agnostic; publishes live `state_dict` tuples |
+| Tensor ABI | Fixed `Slot{...}` registered at startup | Per-step `(name, shape, dtype, gpu_addr)` published; receiver re-registers if shape drifts |
+| Receive target | Direct WRITE into live `param.data` | **Scratch buffer + `model.load_weights()`** by default; opt-in direct refit |
+| Quantization | First-class `ConversionSpec`/`QuantizationSpec` | Trainer pre-quantizes if needed; MX is dtype-agnostic |
+| Lifecycle | rendezvous → register × N → write × N per startup | publish/poll/receive/unpublish per step; no static startup registration |
+| Versioning | Implicit step counter | First-class `extra_parameters.training_step` |
+| Mutability contract | Implicit (root cause of KL drift?) | Explicit `unpublish()` with drain |
+| Cross-framework | prime-rl only | Same client runs in verl, prime-rl, future NeMo-RL |
+| vLLM integration | Worker extension only | Worker extension OR `WeightTransferEngine` plugin (Step 11) |
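+
+The content-addressed source identity in the table can be illustrated with a short sketch. The `SourceIdentity` fields below are hypothetical (the real schema is defined by ModelExpress and is not shown in this doc); the only point is that hashing a canonical encoding gives equal identities equal IDs, with no rank pairing involved.
+
+```python
+import hashlib
+import json
+from dataclasses import asdict, dataclass
+
+
+@dataclass(frozen=True)
+class SourceIdentity:
+    # Hypothetical fields for illustration; not the real MX schema.
+    model_name: str
+    worker_rank: int
+    dtype: str
+
+
+def mx_source_id(identity: SourceIdentity) -> str:
+    """Hash a canonical JSON encoding so equal identities map to equal IDs."""
+    canonical = json.dumps(asdict(identity), sort_keys=True, separators=(",", ":"))
+    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
+
+
+a = mx_source_id(SourceIdentity("Qwen/Qwen3-0.6B", 0, "bfloat16"))
+b = mx_source_id(SourceIdentity("Qwen/Qwen3-0.6B", 0, "bfloat16"))
+c = mx_source_id(SourceIdentity("Qwen/Qwen3-0.6B", 1, "bfloat16"))
+```
+
+Here `a` and `b` are identical, while `c` differs because the rank differs. Sorting keys before hashing keeps the ID stable regardless of field order.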
+
+---
+
+## 3. Path B Architecture
+
+Same NIXL data plane as PI; different prime-rl-side abstractions on top of it.
+
+```text
+┌───────────────────────────────────────────────────────────────────┐
+│ prime-rl trainer process │
+│ ┌─────────────────────────────────────────────────────────────┐ │
+│ │ MxWeightBroadcast (new, in prime_rl/trainer/rl/broadcast/) │ │
+│ │ ─ implements PI's WeightBroadcast ABC │ │
+│ │ ─ delegates the data path to MxTrainingPublisher │ │
+│ │ ─ delegates the control path to MxClient gRPC │ │
+│ └─────────────────────────────────────────────────────────────┘ │
+│ │ │ │
+│ ▼ data ▼ metadata │
+│ MxTrainingPublisher MxClient(server_url=...) │
+│ (modelexpress) (modelexpress) │
+│ │ │ │
+│ ▼ ▼ │
+│ NixlAgentWrapper ┌─────────────────────────────┐ │
+│ (PI's existing class) │ MX Server (gRPC + Redis) │ │
+│ │ │ ─ source registry │ │
+│ ▼ post_write │ ─ poll_for_source │ │
+│ UCX rc_mlx5 RDMA │ ─ pipeline replication │ │
+│ (same wire as PI) │ ─ retention + versioning │ │
+│ │ └─────────────────────────────┘ │
+│ ▼ │
+│ ConnectX-7 NIC ─── RoCE ───► rollout NIC │
+│ │
+└───────────────────────────────────────────────────────────────────┘
+
+┌───────────────────────────────────────────────────────────────────┐
+│ prime-rl inference (vLLM) worker │
+│ ┌─────────────────────────────────────────────────────────────┐ │
+│ │ MxWeightUpdateWorker (new) │ │
+│ │ ─ vLLM worker extension │ │
+│ │ ─ MxRefitReceiver delegate │ │
+│ │ ─ default: scratch buffer + model.load_weights() │ │
+│ │ ─ opt-in: direct refit into live param.data │ │
+│ └─────────────────────────────────────────────────────────────┘ │
+│ │ │
+│ ▼ │
+│ MxRefitReceiver (modelexpress) │
+│ │ │
+│ ▼ │
+│ NixlAgentWrapper (PI's existing class) │
+│ │ │
+│ ▼ │
+│ accepts WRITE from trainer (same as PI) │
+└───────────────────────────────────────────────────────────────────┘
+```
+
+### What we adopt unchanged from PI's PR
+
+These are real engineering wins; we'd be foolish to redo them:
+
+- `NixlAgentWrapper` (UCX agent setup, register_tensor, prep_local/prep_remote, post_write, wait, drain)
+- `pin_ucx_rail` (per-rank NIC pinning that's the difference between 4.8 GB/s and 7.5 GB/s)
+- `classic_cuda_pool` (allocator workaround for `expandable_segments` + `ibv_reg_mr` "local protection" bug)
+- The runtime image stack (UCX 1.19, NIXL 0.10.1, ARM64 quirks)
+- Pre-write SPG barrier / quiescence pattern (the iter15 fix; we'd implement equivalent via MX Server fence RPC)
+- HSDP primary-replica gate (only `dp_replicate == 0` runs the protocol)
+
+These come into MX either as a direct re-export or via a `prime_rl.utils` import; no need to fork.
+
+### What we replace with MX-shape
+
+| PI component | Our replacement |
+|--------------|-----------------|
+| `NIXLWeightBroadcast` | `MxWeightBroadcast` — same prime-rl ABC (`WeightBroadcast`), different internals |
+| `NIXLWeightUpdateWorker` | `MxWeightUpdateWorker` — same vLLM worker_extension_cls slot |
+| `TransportPlan` | Replaced. Per-step iteration over `state_dict` tuples instead of slot table |
+| `model.conversion_specs(layer_idx)` | Removed entirely. Trainer publishes whatever is in `state_dict`; vLLM's `model.load_weights()` performs the HF→kernel format conversion on the inference side (its tested code path) |
+| `Slot{Sharded,Gathered,Expert}` | Removed. The per-rank publishing semantics come from `MxTrainingPublisher`'s `worker_rank` field; tiny-tensor coalescing is a NIXL registration optimization, not a slot type |
+| SPG rendezvous (`StatelessProcessGroup`) | Removed. Discovery is `MxClient.poll_for_source`; per-step barrier is `MxClient.fence` (new), with a fallback to a small SPG over MX-discovered endpoints |
+| `ConversionSpec`/`QuantizationSpec` (FP8 quantize on trainer) | Optional on trainer side. Adopted as `MxQuantizer` if user wants FP8 fast path; default is BF16 passthrough. Inference-side decompresses via vLLM's existing FP8 loader, not via NIXL post-processing. |
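+
+Replacing the slot table with per-step `(name, shape, dtype, gpu_addr)` tuples means the receiver has to notice when a descriptor changes. A minimal sketch of that comparison, using plain Python stand-ins rather than real tensors; the helper name `tensors_to_reregister` and the tensor names are assumptions for illustration, not part of the MX API:
+
+```python
+from typing import Dict, List, Tuple
+
+# (shape, dtype, gpu_addr) keyed by tensor name; plain Python stand-ins for
+# what a publisher would read off a live state_dict.
+Descriptor = Tuple[Tuple[int, ...], str, int]
+
+
+def tensors_to_reregister(
+    registered: Dict[str, Descriptor], published: Dict[str, Descriptor]
+) -> List[str]:
+    """Return names that are new or whose descriptor changed since registration."""
+    return sorted(
+        name for name, desc in published.items() if registered.get(name) != desc
+    )
+
+
+step_n = {
+    "layers.0.q_proj.weight": ((1024, 1024), "bfloat16", 0x7F00_0000),
+    "layers.0.lora_A.weight": ((8, 1024), "bfloat16", 0x7F20_0000),
+}
+step_n1 = {
+    "layers.0.q_proj.weight": ((1024, 1024), "bfloat16", 0x7F00_0000),
+    "layers.0.lora_A.weight": ((16, 1024), "bfloat16", 0x7F40_0000),  # rank grew
+    "layers.0.lora_B.weight": ((1024, 16), "bfloat16", 0x7F60_0000),  # new adapter
+}
+changed = tensors_to_reregister(step_n, step_n1)
+```
+
+Only the two LoRA tensors end up in `changed`; the unchanged base weight keeps its existing registration.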
+
+### What we add that PI doesn't have
+
+These are MX traditional features that translate naturally into prime-rl now that we're not constrained by PI's slot abstraction:
+
+- **Pipeline replication** (TensorHub-style DAG): rollout publishes itself as a secondary source after receive; MX Server load-balances new pollers across trainer + rollouts.
+- **Peer recovery**: a restarting rollout pod pulls from any surviving peer (via `poll_for_sources` ranked by health/locality), not always from trainer.
+- **Versioning + retention**: keep-latest-N versions on MX Server; rollouts can request a specific version or "latest stable."
+- **Mutability contract**: explicit `unpublish(version)` before trainer reuses slot buffers; server blocks until in-flight pulls drain. Directly addresses PI's KL drift hypothesis (write ordering / live param visibility).
+- **Cross-framework**: same `MxTrainingPublisher`/`MxRefitReceiver` already proven in verl and our existing PRIME-RL POC. Path B is the alignment: prime-rl + verl + future NeMo-RL share the same MX abstractions.
+- **Scratch-buffer default**: receive lands in isolated GPU tensors; vLLM `model.load_weights()` applies them via its tested NCCL-equivalent path. If the drift does come from concurrent writes into live params, the KL-drift class of bugs falls away. Direct refit becomes opt-in for users who measure and accept the correctness risk.
+- **Elastic shapes**: per-step `(name, shape, dtype)` publishing means LoRA-RL, dynamic adapters, growing context lengths all work without re-init.
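+
+The mutability contract in the list above (reject new pulls, then block until in-flight pulls drain) can be modeled in-process. This is a toy sketch, assuming a single-process `VersionGate` class invented for illustration; the real contract is enforced by MX Server over gRPC.
+
+```python
+import threading
+from collections import defaultdict
+
+
+class VersionGate:
+    """Toy in-process model of publish / pull / unpublish-with-drain."""
+
+    def __init__(self) -> None:
+        self._cond = threading.Condition()
+        self._published = set()
+        self._in_flight = defaultdict(int)
+
+    def publish(self, version: int) -> None:
+        with self._cond:
+            self._published.add(version)
+
+    def begin_pull(self, version: int) -> bool:
+        with self._cond:
+            if version not in self._published:
+                return False  # already unpublished; caller must re-poll
+            self._in_flight[version] += 1
+            return True
+
+    def end_pull(self, version: int) -> None:
+        with self._cond:
+            self._in_flight[version] -= 1
+            self._cond.notify_all()
+
+    def unpublish(self, version: int, timeout: float = 5.0) -> bool:
+        with self._cond:
+            self._published.discard(version)  # reject new pulls first
+            return self._cond.wait_for(
+                lambda: self._in_flight[version] == 0, timeout=timeout
+            )
+
+
+gate = VersionGate()
+gate.publish(7)
+assert gate.begin_pull(7)
+threading.Timer(0.05, gate.end_pull, args=(7,)).start()
+drained = gate.unpublish(7)     # blocks until the pull above finishes
+late_pull = gate.begin_pull(7)  # rejected: version 7 is no longer published
+```
+
+Removing the version from the published set before waiting is what keeps the drain bounded: no new pull can start once `unpublish` has been called.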
+
+---
+
+## 4. Concrete File Footprint
+
+What changes in `KavinKrishnan/prime-rl:kavink/mx-on-nixl` to ship Path B (relative to current Path A overlay):
+
+### New files
+
+| File | Purpose | Est. LOC |
+|------|---------|----------|
+| `src/prime_rl/trainer/rl/broadcast/mx.py` | `MxWeightBroadcast` — new prime-rl `WeightBroadcast` impl. ~3 methods: `__init__` (publisher + agent setup), `broadcast_weights(model, step)` (publish state_dict tuples, signal orchestrator), `shutdown()` | ~300 |
+| `src/prime_rl/inference/vllm/worker/mx.py` | `MxWeightUpdateWorker` — vLLM worker_extension_cls. Delegates to `MxRefitReceiver` for receive, then `model.load_weights(scratch_iter)` for apply. Direct-refit opt-in via env. | ~250 |
+| `src/prime_rl/configs/...` | New `MxWeightBroadcastConfig` discriminator on `WeightBroadcastConfig` union. `type: "mx"` | ~60 |
+| `docs/weight-transfer-modelexpress.md` | Already requested by `@mikasenghaas` in #2326 review. Documents all four backends (filesystem, nccl, nixl, mx) with selection guidance | ~250 |
+
+### Modified files
+
+| File | Change | Est. LOC |
+|------|--------|----------|
+| `src/prime_rl/configs/trainer.py`, `orchestrator.py`, `rl.py` | Add `MxWeightBroadcastConfig` to discriminated union; thread through unify-mode for the `rl` entrypoint | +40 |
+| `src/prime_rl/trainer/rl/broadcast/__init__.py` | Add `mx` dispatch in `setup_weight_broadcast()` | +15 |
+| `src/prime_rl/inference/vllm/server.py` (or where worker_extension_cls is wired) | `mx` value selects `MxWeightUpdateWorker` | +10 |
+| `src/prime_rl/utils/client.py` (TRANSFER_READY marker) | Touch for `mx` backend too (already protocol-agnostic per the existing review feedback) | +5 |
+
+### Files we don't touch
+
+PI's NIXL backend stays in place. Users who want PI's path get `type: "nixl"`; users who want ours get `type: "mx"`. No conflict between the two backends — they coexist as discriminator options.
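+
+A sketch of how the two backends can coexist behind a `type` discriminator. Stdlib dataclasses stand in for whatever config framework prime-rl actually uses, and `NixlWeightBroadcastConfig` plus `parse_weight_broadcast` are illustrative names; only `MxWeightBroadcastConfig` and the field names from the migration config come from this doc.
+
+```python
+from dataclasses import dataclass
+from typing import Any, Dict, Literal, Union
+
+
+@dataclass
+class NixlWeightBroadcastConfig:
+    host: str
+    port: int
+    inference_world_size: int
+    type: Literal["nixl"] = "nixl"
+
+
+@dataclass
+class MxWeightBroadcastConfig:
+    mx_server_url: str
+    model_name: str
+    inference_world_size: int
+    transfer_mode: Literal["scratch", "direct"] = "scratch"
+    pipeline_replication: bool = False
+    type: Literal["mx"] = "mx"
+
+
+WeightBroadcastConfig = Union[NixlWeightBroadcastConfig, MxWeightBroadcastConfig]
+_BY_TYPE = {"nixl": NixlWeightBroadcastConfig, "mx": MxWeightBroadcastConfig}
+
+
+def parse_weight_broadcast(raw: Dict[str, Any]) -> WeightBroadcastConfig:
+    """Pick the config class from the `type` field, then build it from the rest."""
+    try:
+        cls = _BY_TYPE[raw["type"]]
+    except KeyError:
+        raise ValueError(f"unknown weight_broadcast type: {raw.get('type')!r}")
+    return cls(**raw)
+
+
+cfg = parse_weight_broadcast(
+    {"type": "mx", "mx_server_url": "mx:8001", "model_name": "m", "inference_world_size": 1}
+)
+```
+
+Because dispatch happens on `type` alone, adding `mx` does not change how an existing `nixl` config is parsed.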
+
+### What MX-side code we add (in `ai-dynamo/modelexpress`)
+
+Mostly already exists from the verl + existing PRIME-RL POCs. Net new:
+
+- `MxTrainingPublisher.publish_weights_via_nixl(state_dict, step, agent)` — adapt our existing publisher to use PI's `NixlAgentWrapper` directly (same NIXL bytes as PI, just driven from our publisher class).
+- `MxRefitReceiver.receive_weights_via_nixl(source, agent, target)` — same.
+- `MxClient.fence(model, version, world_size)` — server-mediated barrier (replaces SPG barrier per step). Optional; can fall back to SPG over MX-discovered endpoints if we want to minimize server changes.
+
+Plus the server-side capabilities tracked in `PRIMERL_MX_OVERVIEW.md` §3 (pipeline replication index, retention, peer recovery preference ordering).
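+
+To make the proposed `fence(model, version, world_size)` semantics concrete, here is a toy in-process barrier keyed by `(model, version)`. `FenceServer` is a hypothetical stand-in: the real RPC does not exist yet, and a server implementation would also need cleanup and failure handling that this sketch omits.
+
+```python
+import threading
+from collections import defaultdict
+from typing import Dict, Tuple
+
+
+class FenceServer:
+    """Toy stand-in for a server-mediated barrier keyed by (model, version)."""
+
+    def __init__(self) -> None:
+        self._cond = threading.Condition()
+        self._arrived: Dict[Tuple[str, int], int] = defaultdict(int)
+
+    def fence(self, model: str, version: int, world_size: int, timeout: float = 5.0) -> bool:
+        key = (model, version)
+        with self._cond:
+            self._arrived[key] += 1
+            self._cond.notify_all()
+            # Release only once every expected participant has checked in.
+            return self._cond.wait_for(
+                lambda: self._arrived[key] >= world_size, timeout=timeout
+            )
+
+
+server = FenceServer()
+results = []
+threads = [
+    threading.Thread(target=lambda: results.append(server.fence("demo-model", 3, world_size=3)))
+    for _ in range(3)
+]
+for t in threads:
+    t.start()
+for t in threads:
+    t.join()
+```
+
+All three callers return `True` only after the third arrives. Keying by version means a straggler from step N cannot release the barrier for step N+1.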
+
+---
+
+## 5. Migration Path for Users
+
+Users currently on PI's `type: "nixl"`:
+
+```toml
+# Before (PI's path)
+[weight_broadcast]
+type = "nixl"
+host = "..."
+port = 29502
+inference_world_size = N
+backends = ["UCX"]
+
+# After (MX-native path), no rebuild required
+[weight_broadcast]
+type = "mx"
+mx_server_url = "modelexpress-server.kavin.svc.cluster.local:8001"
+model_name = "..."
+inference_world_size = N
+# transfer_mode = "scratch" # default; opt-in "direct" for the speed/correctness tradeoff
+# pipeline_replication = false # default; opt-in for fan-out scale
+```
+
+Same image, same UCX setup, same NIXL transport — just a different config discriminator and ~600 LOC of new code in prime-rl. PI's existing `nixl` backend stays available.
+
+For ourselves: we delete the monkey-patch on Qwen3 (`qwen3_specs_patch.py`) — it's no longer needed because Path B doesn't require per-model conversion specs. Path A's overlay survives as-is for users who want to stick with PI's transport API exactly.
+
+---
+
+## 6. Pitch Sequence for the Design Conversation
+
+If we propose Path B to PI on the existing draft PR (or in a new sibling PR):
+
+1. **Acknowledge PI's transport win directly.** "Your NIXL transport is excellent — UCX setup, classic_cuda_pool, pin_ucx_rail, FP8 ConversionSpec are all correct. Path B keeps every byte of that."
+
+2. **Frame the divergence as scope.** "We've been running ModelExpress as a metadata + elasticity layer across verl and our internal PRIME-RL POC for several months. The MX-shape API is model-agnostic, server-mediated, scratch-buffer-default — it solves problems your `Slot`/`ConversionSpec` design doesn't address (cross-framework, elastic shapes, KL drift via scratch path, retention, pipeline replication)."
+
+3. **Show the demo evidence.** "Path A overlay (already up as #2343) proved the metadata layer works on top of your transport. Path B extends that with a native MX API — here's the design doc, here's a draft diff."
+
+4. **Make the cohabitation explicit.** "Path B doesn't replace `type: nixl`. It adds `type: mx` as a sibling discriminator. Users pick. PI's GLM-5 production keeps using `nixl`; users who want a model-agnostic / cross-framework / scratch-default path use `mx`. Same UCX runtime image, same NIXL bytes, same `pin_ucx_rail` discipline — different control plane."
+
+5. **Address the KL drift directly.** "The drift you're chasing in iter22-27 is consistent with concurrent live-param writes. The MX scratch-buffer path isn't subject to that bug surface because writes land in isolated tensors, then `model.load_weights()` applies them via vLLM's NCCL-equivalent code path. Happy to A/B this on your iter26/27 config — same NIXL bytes, just a different target buffer. If your drift disappears in scratch mode, that's diagnostic data even if you keep `nixl` as default."
+
+---
+
+## 7. Inflection Points to Pivot A → B
+
+- **PI's PR stalls on KL drift for more than 2 weeks** without a root-cause fix landing. Path B's scratch-buffer default ships independently of that investigation.
+- **Reviewer pushback on Path A's overlay shape**: objections to env-var-as-config or to monkey-patching Qwen3 specs would suggest that a cleaner design is wanted.
+- **Need for elastic / LoRA-RL features** that PI's slot system can't accommodate. Path B's per-step publish handles dynamic shapes naturally.
+- **Cross-framework alignment becomes a strategic priority** — leadership wants "one weight broadcast story across prime-rl + verl + future frameworks" rather than per-framework redesigns.
+- **Scenario A's first NIXL run reveals the same KL drift on Qwen3** as PI saw on GLM-5. That's a strong empirical signal that direct-refit-into-live-params is the bug surface, not GLM-specific. Path B's scratch-default becomes the obvious fix.
+
+---
+
+## 8. Decisions Pending
+
+| Decision | Default if not made |
+|----------|---------------------|
+| Ship Path B alongside Path A or as a follow-up PR? | Follow-up PR; Path A goes first |
+| MX-mediated barrier (`MxClient.fence`) or SPG-over-MX-discovered endpoints? | SPG-over-MX-discovered for v0.1 (smaller server change) |
+| `MxWeightBroadcast` lives in `prime_rl/trainer/rl/broadcast/mx.py` (in-tree) or as a plugin from `modelexpress` package? | In-tree mirroring PI's nixl.py pattern |
+| Scratch vs direct refit default | Scratch (correctness-safe). Direct opt-in via `transfer_mode = "direct"` |
+| Pipeline replication default | Off. Opt-in via `pipeline_replication = true` |
+
+---
+
+## 9. Status
+
+- Path A draft PR: [PrimeIntellect-ai/prime-rl#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343) — Qwen3 conversion_spec patch in flight, image rebuild done, deploy pending tsh refresh.
+- Path B: design only (this doc). No code yet. Ready to author when an inflection point above triggers.
+- Tracking: `recovery/reinforcement learning/PRIME_INTELLECT_PR2326_Analysis.md` for the strategic comparison; `PRIMERL_MX_OVERVIEW.md` for the Path A overlay design.
+
+This doc is the artifact we'd reference in the conversation with PI if Path A doesn't carry the day on its own.
diff --git a/docs/RL/PRIMERL_MX_OVERVIEW.md b/docs/RL/PRIMERL_MX_OVERVIEW.md
new file mode 100644
index 00000000..c9d6cacd
--- /dev/null
+++ b/docs/RL/PRIMERL_MX_OVERVIEW.md
@@ -0,0 +1,839 @@
+# ModelExpress × PRIME-RL — Design Overview
+
+**Last Updated**: April 24, 2026
+**Status**: **Scenarios A, B, and C all green on GB200** — overlay validated end-to-end against PI's `nixl-weight-transfer` branch ([#2326](https://github.com/PrimeIntellect-ai/prime-rl/pull/2326)). All three scenarios completed 20/20 RL training steps with real NIXL RDMA pushes. Measured per-rank wire bandwidth **7.82–8.84 GB/s** (596 MB / ~80 ms per push, 310 slots), aggregate net BW 35–39 GB/s — exceeds PI's reported ~7.5 GB/s wire target. Scenario C also produced the pipeline-replication catalog entry (`rollout-source-0-*`), confirming the MX-side protocol works end-to-end. Scenarios D (scratch-buffer diagnostic) and E (peer recovery) are designed but deferred — they're follow-up axes (KL-drift triangulation, fault tolerance) not on the critical path for the overlay PR. Draft PR: [#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343).
+
+This document covers how ModelExpress (MX) plugs into [PRIME-RL](https://github.com/PrimeIntellect-ai/prime-rl)'s NIXL weight-transfer path as a **metadata and elasticity layer on top of** the existing `NIXLWeightBroadcast` / `TransportPlan` introduced by PR #2326. We do not reimplement their transport. We replace the SPG (StatelessProcessGroup) rendezvous with an MX-Server-mediated discovery plane, add pipeline replication, add a mutability contract, and enable a scratch-buffer diagnostic mode — all opt-in behind a single config flag.
+
+---
+
+## 1. Design Overview
+
+### What MX adds to PRIME-RL's NIXL backend
+
+PR #2326 gives PRIME-RL a bit-exact RDMA weight transport built on NIXL/UCX over RoCE, with slots (`ShardedSlot` / `GatheredSlot` / `ExpertSlot`), model-agnostic `ConversionSpec` / `QuantizationSpec`, FP8 trainer-side quantization, HSDP primary-replica push, per-rank NIC pinning, and an `expandable_segments`-safe `CUDAPluggableAllocator` slot pool. The transport works. What it doesn't have is a dynamic discovery plane.
+
+| Layer | Role in PR #2326 | Role with MX overlay |
+|-------|------------------|----------------------|
+| Data plane | NIXL RDMA (UCX / `rc_mlx5` / RoCE) | **Unchanged** — identical bytes on the wire |
+| Slot / bucket system | `ShardedSlot`, `GatheredSlot`, `ExpertSlot`, `TransportPlan` | **Unchanged** — imported as-is |
+| Publishing topology | **Per-rank sharding-aware** — each trainer rank publishes its own FSDP / TP / EP shard; no rank-0 allgather | **Unchanged** — this is a core property of PI's foundation, inherited by the overlay (see §3.9) |
+| Quantization / conversion | `ConversionSpec`, `QuantizationSpec` | **Unchanged** — same trainer-side FP8 path |
+| Rendezvous / discovery | SPG — static, rank-paired, global-world-size fixed at init | **Replaced by MX Server** (gRPC + Redis) when `rendezvous: mx_server` is set |
+| Topology | Star (trainer rank k → inference rank k, 1:1, no fan-in to rank 0) — trainer NIC is the single source for all fan-out | **Dynamic DAG** — trainer seeds the first rollout; each finished rollout becomes an additional source; MX Server load-balances new pollers across the growing source set. Same TensorHub pattern. Trainer NIC stops being a bottleneck once any rollout has received (§3.2) |
+| Mutability contract | None | **Explicit `publish` / `unpublish`** — trainer publisher marks slots immutable during rollout pulls, mutable before `optimizer.step()` |
+| Elastic topology | No — SPG locks `dp_shard×cp + inference_ws` at boot | **Yes** — rollouts can join / leave mid-run via `poll_for_source` |
+| Retention | None — no version history | **Keep-latest-N** — MX Server reaper preserves designated versions, CPU-offloads the last GPU copy if necessary |
+| Cross-framework | prime-rl only | **Same MX client** also powers verl `MxCheckpointEngine`, future NeMo-RL |
+| Expert-aware source tracking (MoE) | Implicit (ExpertSlot in client; no server-level index) | **Explicit server-side `(model, version, expert_id) → worker` index** + `poll_for_expert_source` RPC. Low-hanging win — primitives already exist in MX Server, overlay wires them up (§3.7) |
+| Peer recovery on pod restart | Not available — recovering rank must re-pull from trainer | **Multi-source discovery** — `poll_for_sources` returns ranked live peers holding the current version of rank k's shard; recovering rank pulls from nearest/least-loaded peer. Uses the same source index pipeline replication writes to. No event log / no version replay (§3.10) |
+| Scratch-buffer diagnostic | Not supported (direct refit only) | **Opt-in via `transfer_mode: scratch`** — uses PI's same transport but stages into isolated GPU tensors + `model.load_weights()` for KL-drift triangulation |
+
+### Component diagram (vertical, document-friendly)
+
+```mermaid
+graph TB
+ subgraph driver["Driver · CPU · orchestrator process"]
+ orch["RL Orchestrator (PI, unchanged)"]
+ httpapi["/pause /resume /update_weights (vLLM WeightTransferEngine endpoints)"]
+ orch --> httpapi
+ end
+
+ subgraph mx_meta["Metadata Plane · CPU · MX overlay adds"]
+ mx["MX Server (gRPC)"]
+ redis[("Redis")]
+ mx --> redis
+ end
+
+ subgraph trainer["Trainer node · FSDP2 + optimizer"]
+ direction TB
+ tw["Trainer ranks × N (dp_shard × cp)"]
+ tp["NIXLWeightBroadcast + TransportPlan (PI, unchanged)"]
+ pub["MxTrainingPublisher (MX overlay, env-var gated)"]
+ tnixl(["NIXL Agent × N (PI)"])
+ tw --> tp
+ tp -.->|register_coordinator once at boot| pub
+ tp --> tnixl
+ end
+
+ subgraph rollout["Rollout nodes · vLLM TP"]
+ direction TB
+ cew["NIXLWeightUpdateWorker × M (PI, unchanged)"]
+ rcv["MxRefitReceiver (MX overlay, env-var gated)"]
+ rnixl(["NIXL Agent × M (PI)"])
+ vllm["vLLM engine × M (live params)"]
+ cew --> rcv
+ cew --> rnixl
+ cew ==> vllm
+ end
+
+ pub -. "gRPC publish_spg_coordinator (boot) · mark_version_ready (per step)" .-> mx
+ rcv -. "gRPC discover_spg_coordinator (boot)" .-> mx
+ rcv -. "gRPC publish_rollout_source (§3.2 pipeline replication, optional)" .-> mx
+
+ cew <== "SPG all_gather_obj × 2 rounds (agent_meta, slot_layout, xfer descs) PI code, unchanged" ==> tp
+ tnixl <== "NIXL RDMA WRITE trainer → rollout recv buffer RoCE · rc_mlx5 (PI, unchanged)" ==> rnixl
+
+ style driver fill:#1a1a2e,stroke:#533483,color:#e0e0e0
+ style mx_meta fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0
+ style trainer fill:#0f3460,stroke:#533483,color:#e0e0e0
+ style rollout fill:#0f3460,stroke:#533483,color:#e0e0e0
+ style orch fill:#533483,stroke:#e94560,color:#fff
+ style httpapi fill:#533483,stroke:#e94560,color:#fff
+ style tw fill:#533483,stroke:#e94560,color:#fff
+ style tp fill:#533483,stroke:#e94560,color:#fff
+ style cew fill:#533483,stroke:#e94560,color:#fff
+ style vllm fill:#533483,stroke:#e94560,color:#fff
+ style pub fill:#1b5e20,stroke:#4caf50,color:#fff
+ style rcv fill:#1b5e20,stroke:#4caf50,color:#fff
+ style mx fill:#1b5e20,stroke:#4caf50,color:#fff
+ style tnixl fill:#2e7d32,stroke:#66bb6a,color:#fff
+ style rnixl fill:#2e7d32,stroke:#66bb6a,color:#fff
+ style redis fill:#162447,stroke:#533483,color:#e0e0e0
+```
+
+**Legend**:
+
+- **Purple boxes + solid purple edges** — existing PRIME-RL / vLLM / PI #2326 code the overlay imports and uses as-is.
+- **Green boxes + dotted green edges** — MX additions: MX Server, overlay client classes, gRPC control-plane calls.
+- **Data plane** (trainer NIXL ↔ rollout NIXL, solid double-edge) is **100% PI** — MX does not see or touch weight bytes.
+- **SPG metadata rounds** (trainer ↔ rollout double-edge between `TransportPlan` and `NIXLWeightUpdateWorker`) are **100% PI** — MX only swaps how participants *find* the SPG coordinator; the two `all_gather_obj` rounds themselves are untouched.
+
+### Key ideas
+
+- **MX Server stores metadata only.** Slot layouts, tensor descriptors, NIXL agent blobs, version numbers. It never touches weight bytes.
+- **The data path is PI's, unchanged.** NIXLWeightBroadcast + TransportPlan + Slot classes are imported and used as-is. Our value-add is what happens *before* (discovery) and *alongside* (lifecycle, pipeline replication, diagnostics).
+- **Opt-in via one config field.** `weight_broadcast.rendezvous: "spg" | "mx_server"` (default `"spg"`). Flipping the flag changes only how participants find each other; the data path does not diverge.
+- **Pipeline replication is a client-side change.** After a rollout receives weights, it optionally re-registers itself as a source. The next rollout to poll discovers *either* the trainer or a replicated rollout; the closer or less-loaded source wins. This amplifies trainer NIC bandwidth in fan-out-heavy topologies.
+- **Scratch-buffer diagnostic mode** reuses PI's transport but lands writes in isolated GPU tensors, then applies via `model.load_weights()`. Used for triangulating correctness issues like the KL drift in #2326.
+
+---
+
+## 2. Timing Diagram — One `update_weights` Step
+
+Shows the MX-mediated path (`rendezvous: mx_server`). The SPG path is unchanged from PI #2326. The overlay's v0.1 scope is **coordinator discovery only** — once the SPG coordinator is found via MX, PI's existing 2-round `all_gather_obj` metadata exchange and `TransportPlan` RDMA WRITE run bit-identically.
+
+```mermaid
+sequenceDiagram
+ participant O as Orchestrator
+ participant T as Trainer ranks (NIXLWeightBroadcast)
+ participant PUB as MxTrainingPublisher
+ participant MX as MX Server
+ participant RCV as MxRefitReceiver
+ participant R as NIXLWeightUpdateWorker (rollout ranks)
+ participant V as vLLM engine
+
+ rect rgba(76,175,80,0.08)
+ Note over T,MX: Boot-time (once per run) — late-bound SPG discovery
+ T->>PUB: register_coordinator(model, host, port)
+ PUB->>MX: gRPC publish_spg_coordinator(model, host:port)
+ R->>RCV: init(model_name, worker_rank=k)
+ RCV->>MX: gRPC discover_spg_coordinator(model)
+ MX-->>RCV: SPG host:port
+ end
+
+ Note over T: optimizer.step() complete
+
+ O->>R: POST /pause
+ R-->>O: 200 OK (quiesced)
+
+ rect rgba(83,52,131,0.10)
+ Note over T,R: SPG 2-round metadata exchange (PI code, unchanged)
+ par per step
+ T->>T: StatelessProcessGroup(host, port, rank, world_size)
+ R->>R: StatelessProcessGroup(host, port, rank, world_size)
+ end
+ T-->>R: Round 1 — NIXL agent_meta, slot_layout, recv-buffer descriptors
+ T-->>R: Round 2 — per-slot xfer descriptors
+ T->>T: dist.barrier() (PI iter15 pre-write quiescence)
+ end
+
+ rect rgba(83,52,131,0.10)
+ Note over T,R: Data plane — PI RDMA WRITE (unchanged)
+ loop per slot bucket (TransportPlan.drain)
+ T-->>R: NIXL RDMA WRITE → rollout's recv buffer (RoCE, rc_mlx5)
+ end
+ end
+
+ par finalize
+ T->>PUB: mark_version_ready(N)
+ PUB->>MX: gRPC mark_version_ready(model, version=N)
+ R->>RCV: finalize()
+ RCV->>V: direct refit into live params or scratch → model.load_weights()
+ end
+
+ opt pipeline_replication=true
+ RCV->>MX: gRPC publish_rollout_source(model, version=N, agent_meta)
+ Note over MX: future pollers may be steered to this rollout (see §3.2 DAG)
+ end
+
+ opt async_level ≥ 1 — trainer about to mutate slots
+ T->>PUB: unpublish(version=N)
+ PUB->>MX: gRPC unpublish(model, version=N)
+ MX->>MX: wait for in-flight pulls (version N) to drain
+ MX-->>PUB: OK (buffers safe to mutate)
+ end
+
+ O->>R: POST /resume
+ R-->>O: 200 OK
+```
+
+**Legend**: Green-tinted block = one-time boot path (MX discovery — the only control-plane change the v0.1 overlay makes). Purple-tinted blocks = per-step flow that's **unchanged from PI #2326** (SPG metadata rounds, barrier, RDMA WRITE). MX Server hooks at finalize + optional blocks (`publish_rollout_source`, `unpublish`) are additive — absent when the corresponding feature flag is unset.
+
+### Observed per-step timing
+
+All three scenarios measured on GB200 (Qwen3-0.6B BF16, 2 trainer ranks × 1 inference rank, customer-gpu-o7v pool, 20 RL training steps each).
+
+| Phase | Scenario A — PI SPG baseline | Scenario B — MX rendezvous | Scenario C — MX + pipeline replication |
+|-------|------------------------------|----------------------------|------------------------------------------|
+| Rendezvous (SPG init vs MX gRPC) | ~0.8s first call, none per-step (single session for the whole run) | **~1s gRPC publish + discover** (boot-time, single call); per-step = 0 | **~1s gRPC publish + discover + 1 publish_as_rollout_source**; per-step = 0 |
+| Per-push RDMA WRITE | 596 MB bucket / 310 slots; avg total = 85 ms (convert 67 ms + post+wait 16 ms + barrier 1 ms) | same as A (transport untouched) | **measured: convert 60 ms + post+wait 15 ms + barrier 1.5 ms = ~77 ms total** |
+| Per-rank wire BW | not separately captured in A | not separately captured in B | **7.82–8.84 GB/s** (avg ~8.1 GB/s, exceeds PI's reported ~7.5 GB/s target) |
+| Aggregate net BW | not separately captured | not separately captured | **35–39 GB/s** (rank 0 + rank 1 combined NIC throughput) |
+| Avg trainer step time (steps 7-19) | ~5.1s | ~4.22s | ~4.7s |
+| Trainer steps completed | **20 / 20** | **20 / 20** | **20 / 20** |
+| `/update_weights` 200 OK rate | 100% | 100% | 100% |
+
+Parity with PI on the data path is the acceptance criterion for the MX overlay, and it was met. A and B use identical NIXL pushes (B's MX rendezvous swap doesn't touch the data path). C demonstrates the same wire-rate transfer, and its catalog gains a `rollout-source-0` entry that future pollers could use.
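+
+As a sanity check on the table, the timings can be converted to rates directly. The sketch below assumes decimal units (1 MB = 1e6 bytes, 1 GB = 1e9 bytes). This doc does not state which duration the reported per-rank and aggregate figures divide by, so the code only shows what the table's own numbers imply.
+
+```python
+def gb_per_s(megabytes: float, milliseconds: float) -> float:
+    """Decimal units: 1 MB = 1e6 bytes, 1 GB = 1e9 bytes."""
+    return (megabytes * 1e6) / (milliseconds * 1e-3) / 1e9
+
+
+BUCKET_MB = 596  # per-push bucket size from the table
+
+rate_total_a = gb_per_s(BUCKET_MB, 85)  # scenario A, full push (convert + post+wait + barrier)
+rate_total_c = gb_per_s(BUCKET_MB, 77)  # scenario C, full push
+rate_wire_a = gb_per_s(BUCKET_MB, 16)   # scenario A, post+wait only
+rate_wire_c = gb_per_s(BUCKET_MB, 15)   # scenario C, post+wait only
+```
+
+Dividing by the full push time gives roughly 7.0 GB/s (A) and 7.7 GB/s (C). Dividing by the post+wait phase alone gives roughly 37.3 GB/s (A) and 39.7 GB/s (C).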
+
+**Known caveat for comparison against PI's reported 12-node prod numbers**: our runs use SDPA (our ARM64 image ships a flash-attn import stub; real kernels require a ~3h QEMU compile). This caps trainer compute throughput at ~6.7k tokens/s/rank vs PI's prod numbers which assume flash-attn. The **NIXL transfer** portion is unaffected (same UCX rc_mlx5, same bytes on wire); only the **step-time** fraction attributable to training (forward + backward) is inflated. Flash-attn parity is a P1 follow-up in `PRIMERL_POC_Next_Steps.md`.
+
+**Pipeline-replication catalog state** (scenario C, after init): MX Server's `list_sources` for the run identity returned 4 entries:
+
+```text
+worker_rank=0 worker_id=primerl-overlay-scenario-c-trainer-0-997daac3 # SPG coordinator
+worker_rank=1 worker_id=primerl-overlay-scenario-c-trainer-1-354aaebe # rank-1 self-publish
+worker_rank=0 worker_id=primerl-overlay-scenario-c-inference-0-8db04e3d # standard rollout
+worker_rank=0 worker_id=primerl-overlay-scenario-c-rollout-source-0-faaaf5e5 # ← pipeline replication
+```
+
+The fourth entry — written by `publish_as_rollout_source()` after the inference rollout finished receiving weights — is the empirical proof that the pipeline-replication mechanism works on the MX side. Subsequent rollouts polling for this `(model, version)` would discover *both* the trainer and this rollout as candidate sources. Bandwidth amplification per the §3.2 DAG model is gated on extending PI's SPG to support dynamic world_size (so a late-joining rollout can actually attach mid-run); that extension is in the §3 follow-up list.
+
+---
+
+## 3. ModelExpress Value Layer
+
+This section documents what the overlay changes relative to PI's PR #2326.
+
+### 3.1 SPG → MX Server rendezvous
+
+**What SPG provides in #2326**: A fixed-world-size group over TCP, used to exchange NIXL agent metadata at init. Every participant must be present at the same time; adding/removing a rollout requires a full process restart.
+
+**What MX Server provides**:
+- Each trainer rank calls `MxTrainingPublisher.publish(agent_meta, slot_layout, version)` once per step (gRPC).
+- Each rollout calls `MxRefitReceiver.poll_for_source(model_name, worker_rank, min_version)` — returns the matching trainer rank's agent metadata and slot layout.
+- Poll is idempotent and cache-friendly; rollouts can join mid-run, leave, or be restarted without affecting other participants.
+
+**Config surface**:
+
+```yaml
+weight_broadcast:
+ type: nixl # use PI's transport
+ rendezvous: mx_server # instead of spg
+ mx_server_url: modelexpress-server.kavin.svc.cluster.local:8001
+ model_name: "zai-org/GLM-4.5-Air-FP8" # example; see §4 for final selection
+```
+
+When `rendezvous: spg` (the default), behavior is 100% identical to PI's PR #2326.
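+
+A minimal in-memory sketch of the publish/poll contract above. `Entry` and `Catalog` are hypothetical stand-ins for MX Server (which is gRPC + Redis); only the `(model_name, worker_rank, min_version)` lookup shape is taken from this section.
+
+```python
+from dataclasses import dataclass
+from typing import Optional
+
+
+@dataclass(frozen=True)
+class Entry:
+    """Hypothetical catalog record; field names are illustrative."""
+    model_name: str
+    worker_rank: int
+    version: int
+    agent_meta: bytes
+    slot_layout: tuple
+
+
+class Catalog:
+    """Toy stand-in for the MX Server source catalog (an assumption, not the real server)."""
+
+    def __init__(self) -> None:
+        self._latest: dict[tuple[str, int], Entry] = {}
+
+    def publish(self, entry: Entry) -> None:
+        # Keep only the newest version per (model, rank); republishing is harmless.
+        key = (entry.model_name, entry.worker_rank)
+        current = self._latest.get(key)
+        if current is None or entry.version >= current.version:
+            self._latest[key] = entry
+
+    def poll_for_source(self, model_name: str, worker_rank: int,
+                        min_version: int) -> Optional[Entry]:
+        # Read-only lookup, so repeated polls return the same answer (idempotent).
+        entry = self._latest.get((model_name, worker_rank))
+        if entry is None or entry.version < min_version:
+            return None
+        return entry
+```
+
+Because `poll_for_source` only reads, a rollout that joins late or restarts can call it repeatedly without affecting other participants.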
+
+### 3.2 Pipeline replication — dynamic DAG of rollouts-as-sources
+
+Rollouts form a **dynamic DAG**, not a static star. Every `publish_rollout_source(version=N)` call adds a new parent edge available to unfinished rollouts. Every `poll_for_source(version=N)` gets load-balanced across the currently-available parent set (trainer + any rollouts that have already finalized version N). The DAG is built organically as receives complete; there is no precomputed topology.
+
+This is the same architectural pattern as TensorHub's Reference-Oriented Storage (ByteDance, April 2026; see `recovery/reinforcement learning/TensorHub_Analysis.md` for the design comparison).
+
+```yaml
+weight_broadcast:
+ rendezvous: mx_server
+ pipeline_replication: true # default false
+```
+
+**DAG buildup over time** (12 rollouts, single trainer source for a given rank k). Each phase shows which workers act as sources and which are still polling. Edges represent "may be selected as a source by future pollers."
+
+```mermaid
+flowchart TB
+ subgraph t0["t = 0 — only the trainer is a source"]
+ T0[Trainer]
+ P0(["R0..R11 (polling)"])
+ T0 --> P0
+ end
+
+ subgraph t1["t = t1 — R0 finished first, now also a source"]
+ T1[Trainer]
+ R0_1[R0]
+ P1(["R1..R11 (polling)"])
+ T1 --> P1
+ T1 --> R0_1
+ R0_1 --> P1
+ end
+
+ subgraph t2["t = t2 — R1 and R2 finalize from {Trainer, R0}"]
+ T2[Trainer]
+ R0_2[R0]
+ R1_2[R1]
+ R2_2[R2]
+ P2(["R3..R11 (polling)"])
+ T2 --> P2
+ R0_2 --> P2
+ R1_2 --> P2
+ R2_2 --> P2
+ end
+
+ subgraph t3["t = t3 — Trainer + R0..R6 serve R7..R11"]
+ T3[Trainer]
+ R06[R0..R6]
+ P3(["R7..R11 (polling)"])
+ T3 --> P3
+ R06 --> P3
+ end
+
+ subgraph t4["t = t4 — all 12 rollouts hold version N"]
+ Done["{Trainer, R0..R11}"]
+ end
+
+ t0 --> t1 --> t2 --> t3 --> t4
+
+ style T0 fill:#533483,stroke:#e94560,color:#fff
+ style T1 fill:#533483,stroke:#e94560,color:#fff
+ style T2 fill:#533483,stroke:#e94560,color:#fff
+ style T3 fill:#533483,stroke:#e94560,color:#fff
+ style R0_1 fill:#1b5e20,stroke:#4caf50,color:#fff
+ style R0_2 fill:#1b5e20,stroke:#4caf50,color:#fff
+ style R1_2 fill:#1b5e20,stroke:#4caf50,color:#fff
+ style R2_2 fill:#1b5e20,stroke:#4caf50,color:#fff
+ style R06 fill:#1b5e20,stroke:#4caf50,color:#fff
+ style Done fill:#1b5e20,stroke:#4caf50,color:#fff
+```
+
+**Per-phase outbound bandwidth** (assuming each NIC = 1 unit of outbound):
+
+| Phase | Sources serving pollers | Aggregate outbound | Pollers remaining |
+|---|---|---|---|
+| `t=0` | `{Trainer}` | 1× | 12 |
+| `t=t1` | `{Trainer, R0}` | 2× | 11 |
+| `t=t2` | `{Trainer, R0..R2}` | 4× | 9 |
+| `t=t3` | `{Trainer, R0..R6}` | 8× | 5 |
+| `t=t4` | (all done) | — | 0 |
+
+**Bandwidth math**: A naive star with T trainer NICs serving R rollouts caps aggregate throughput at T × per-NIC-BW, regardless of R. The DAG caps aggregate throughput at R × per-NIC-BW (every GPU's outbound contributes once it has received). For R=12 and T=8 on the PI prod shape, this is a 1.5× headroom; for R=64 on a future scale-out, it's 8× headroom.
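+
+The phase table and the headroom figures can be reproduced with a small simulation. It assumes each source serves exactly one poller per round and every transfer takes one round, which is a simplification; `dag_phases` and `headroom` are illustrative names.
+
+```python
+def dag_phases(num_rollouts: int, trainer_sources: int = 1) -> list[tuple[int, int]]:
+    """Return (sources, pollers_remaining) per phase for a one-pull-per-source round model."""
+    sources, remaining = trainer_sources, num_rollouts
+    phases = [(sources, remaining)]
+    while remaining > 0:
+        served = min(sources, remaining)  # each source feeds one poller this round
+        sources += served                 # finished rollouts become sources
+        remaining -= served
+        phases.append((sources, remaining))
+    return phases
+
+
+def headroom(num_rollouts: int, trainer_nics: int) -> float:
+    """Ratio of the DAG cap (R x per-NIC BW) to the star cap (T x per-NIC BW)."""
+    return num_rollouts / trainer_nics
+```
+
+`dag_phases(12)` yields source counts 1, 2, 4, 8 with 12, 11, 9, 5 pollers remaining, matching the table; `headroom(12, 8)` and `headroom(64, 8)` give the 1.5× and 8× figures.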
+
+**Load-balancing preference** (pipeline replication mode, distinct from §3.10 peer recovery):
+
+- Spread load evenly across currently-available sources (round-robin within a locality tier).
+- Prefer same-rack sources over cross-rack to minimize inter-switch hops.
+- Avoid overloading the trainer — once any rollout is available, weight the trainer lower in selection so its NIC stays free to seed new pushes.
+
+Contrast with §3.10 peer-recovery preference, which prefers same-node > same-rack > any > trainer-last (optimizing for recovery *latency* vs pipeline *throughput*).
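+
+One way the pipeline-mode preferences could be combined, as a sketch only. `Candidate`, `pick_source`, and the `trainer_penalty` weight are hypothetical; least-loaded selection is used here as a stand-in for round-robin within a locality tier.
+
+```python
+from dataclasses import dataclass
+
+
+@dataclass
+class Candidate:
+    """Hypothetical view of one source; all fields are illustrative."""
+    source_id: str
+    rack: str
+    is_trainer: bool
+    active_pulls: int = 0
+
+
+def pick_source(candidates: list[Candidate], poller_rack: str,
+                trainer_penalty: int = 4) -> Candidate:
+    """Pick a source: same-rack tier first, then lowest effective load."""
+    if not candidates:
+        raise ValueError("no sources available")
+    same_rack = [c for c in candidates if c.rack == poller_rack]
+    tier = same_rack or candidates
+    # Down-weight the trainer only once some rollout source exists.
+    any_rollout = any(not c.is_trainer for c in candidates)
+
+    def effective_load(c: Candidate) -> int:
+        penalty = trainer_penalty if (c.is_trainer and any_rollout) else 0
+        return c.active_pulls + penalty
+
+    chosen = min(tier, key=effective_load)
+    chosen.active_pulls += 1  # the caller would decrement when the pull completes
+    return chosen
+```
+
+With a trainer and one rollout in the poller's rack, the rollout absorbs the first few pulls and the trainer is only selected once the rollout's load reaches the penalty.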
+
+**Server-side state used** (shared with peer recovery in §3.10 — same index, two entry points):
+
+```text
+sources_index : Map<(model, version, worker_rank), Set>
+source_health : Map
+source_load : Map // for load-balancing
+```
+
+**Current limit (TensorHub has, we don't yet)**: We call `publish_rollout_source()` after `finalize()` — i.e., once *all* slots have been received. TensorHub goes further and lets a partially-replicated rollout serve its *completed* slots while still receiving others. This deepens the pipelining further (rollout A's slot 0 can feed rollout B's slot 0 even while rollout A is still pulling slot 1 from the trainer). We've chosen post-finalize-only for the initial overlay to keep the correctness surface small; partial-replica serving is in the future-work list with a clear enablement path:
+
+1. Publisher exposes `publish_partial_source(version, completed_slots[])` that accepts a slot-id bitmap.
+2. Server-side index keys on `(model, version, worker_rank, slot_id)` instead of `(model, version, worker_rank)`.
+3. Receiver filters candidate sources per slot, composes multi-source pulls.
+
+Low-risk to add post-merge; no impact on the initial PR's success criteria.
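+
+The three enablement steps can be sketched as a per-slot index. This is a hypothetical illustration, not the planned implementation: `PartialSourceIndex` and `plan_pull` are invented names, and a plain list of slot ids stands in for the slot-id bitmap.
+
+```python
+from collections import defaultdict
+
+SlotKey = tuple[str, int, int, int]  # (model, version, worker_rank, slot_id)
+
+
+class PartialSourceIndex:
+    """Toy per-slot source index (hypothetical; mirrors step 2 of the list above)."""
+
+    def __init__(self) -> None:
+        self._by_slot: dict[SlotKey, set[str]] = defaultdict(set)
+
+    def publish_partial_source(self, source_id: str, model: str, version: int,
+                               worker_rank: int, completed_slots: list[int]) -> None:
+        for slot_id in completed_slots:
+            self._by_slot[(model, version, worker_rank, slot_id)].add(source_id)
+
+    def plan_pull(self, model: str, version: int, worker_rank: int,
+                  wanted_slots: list[int]) -> dict[int, str]:
+        """Map each wanted slot to one candidate source; slots with no source are omitted."""
+        plan: dict[int, str] = {}
+        for slot_id in wanted_slots:
+            sources = self._by_slot.get((model, version, worker_rank, slot_id))
+            if sources:
+                plan[slot_id] = sorted(sources)[0]  # deterministic pick for the sketch
+        return plan
+```
+
+A receiver would retry the omitted slots on its next poll, once some source has completed them.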
+
+### 3.3 Mutability contract (publish / unpublish)
+
+PI's protocol currently relies on a `dist.barrier()` after `NIXL_READY` to ensure trainer ranks don't start writing while inference is still reading (iter15 fix). This works for synchronous push, but at async level ≥ 1 (pre-fetch next rollouts while current training step runs) the trainer will re-use slot buffers for the next `optimizer.step()` while rollouts may still be pulling the previous version.
+
+The MX overlay adds:
+- `unpublish(version=N)` gRPC — trainer calls this just before buffer mutation.
+- MX Server blocks `unpublish` until all in-flight pulls for version N have completed (tracked via heartbeat ACKs from rollouts).
+- Publisher then signals the trainer that slot buffers are safe to mutate.
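+
+The blocking behavior of `unpublish` can be sketched with a condition variable. This is an assumption-laden illustration: `VersionGate` is a hypothetical name, explicit `begin_pull` / `end_pull` calls stand in for the heartbeat ACKs, and rejecting new pulls once `unpublish` has started is a design choice of the sketch, not something stated above.
+
+```python
+import threading
+from typing import Optional
+
+
+class VersionGate:
+    """Toy in-flight pull tracker for one published version (hypothetical names)."""
+
+    def __init__(self) -> None:
+        self._cond = threading.Condition()
+        self._in_flight = 0
+        self._closed = False
+
+    def begin_pull(self) -> bool:
+        with self._cond:
+            if self._closed:
+                return False  # version is being unpublished; reject new pulls
+            self._in_flight += 1
+            return True
+
+    def end_pull(self) -> None:
+        with self._cond:
+            self._in_flight -= 1
+            self._cond.notify_all()
+
+    def unpublish(self, timeout: Optional[float] = None) -> bool:
+        """Block until in-flight pulls drain; True means buffers are safe to mutate."""
+        with self._cond:
+            self._closed = True
+            return self._cond.wait_for(lambda: self._in_flight == 0, timeout)
+```
+
+The trainer would call `unpublish()` just before reusing slot buffers and proceed only when it returns `True`.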
+
+### 3.4 Retention protocol
+
+MX Server enforces keep-latest-N versions per model (default N=2). If the last GPU-resident copy of a retained version is about to be unpublished, the server offloads that version to CPU memory as a fallback. Rollouts pulling a version that exists only on CPU pay a higher latency but don't fail. Prevents version loss under elastic churn; matches TensorHub retention semantics.
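+
+A rough sketch of the keep-latest-N rule with the CPU fallback, assuming a hypothetical `RetentionTable` that tracks only which sources hold each version (no bytes are moved here).
+
+```python
+class RetentionTable:
+    """Toy keep-latest-N tracker with a CPU-offload fallback (hypothetical names)."""
+
+    def __init__(self, keep_latest: int = 2) -> None:
+        if keep_latest < 1:
+            raise ValueError("keep_latest must be >= 1")
+        self.keep_latest = keep_latest
+        self.gpu_sources: dict[int, set[str]] = {}  # version -> live GPU source ids
+        self.cpu_copies: set[int] = set()           # versions offloaded to CPU memory
+
+    def _retained(self) -> list[int]:
+        versions = sorted(set(self.gpu_sources) | self.cpu_copies)
+        return versions[-self.keep_latest:]
+
+    def publish(self, version: int, source_id: str) -> None:
+        self.gpu_sources.setdefault(version, set()).add(source_id)
+        keep = set(self._retained())
+        for old in [v for v in self.gpu_sources if v not in keep]:
+            del self.gpu_sources[old]  # older than the latest N: drop
+        self.cpu_copies &= keep
+
+    def unpublish(self, version: int, source_id: str) -> str:
+        """Return where the version lives once this source has gone away."""
+        retained = version in self._retained()
+        sources = self.gpu_sources.get(version, set())
+        sources.discard(source_id)
+        if sources:
+            return "gpu"
+        self.gpu_sources.pop(version, None)
+        if retained:
+            self.cpu_copies.add(version)  # last GPU copy gone: fall back to CPU
+            return "cpu"
+        return "dropped"
+```
+
+With the default N=2, unpublishing the only GPU holder of a retained version returns `"cpu"` rather than losing the version.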
+
+### 3.5 Scratch-buffer diagnostic mode
+
+PI's direct-refit writes RDMA-received bytes directly into live vLLM parameter memory. The KL drift investigation in #2326 (27+ iterations, unresolved) narrowed the bug to "NIXL write mechanism itself" — write-ordering / visibility / tensor-identity hazards at the intersection of RDMA and live CUDA tensors.
+
+The MX overlay offers an alternate target:
+
+```yaml
+weight_broadcast:
+ rendezvous: mx_server
+ transfer_mode: scratch # default: direct
+```
+
+In `scratch` mode:
+- RDMA writes land in isolated scratch GPU tensors allocated by the receiver, not in live vLLM params.
+- After the transfer completes, `model.load_weights()` applies them — the same code path NCCL uses.
+- Bit-exact byte check passes in both modes (transport is identical); any KL divergence between `direct` and `scratch` isolates the bug to the direct-refit target layout (kernel format, stride, identity), *not* the NIXL mechanism.
+
+Memory cost is scratch-sized (~3.5 GB for 1.5B, ~15 GB for 7B, ~30 GB for 32B-class MoE). Intended for diagnostic runs, not production.
+
+### 3.6 Cross-framework reuse
+
+The same `MxTrainingPublisher` / `MxRefitReceiver` classes underpin `MxCheckpointEngine` in [verl](https://github.com/volcengine/verl). See `docs/RL/VERL_MX_OVERVIEW.md` for that integration's topology. Future NeMo-RL integration uses the same client. No framework-specific server code exists in MX Server.
+
+### 3.7 Expert-aware source tracking (MoE)
+
+For MoE models at expert parallelism EP > 1, each inference worker holds only a *subset* of experts. Broadcasting every expert to every worker (NCCL, filesystem) wastes `(EP - 1)/EP` of the bandwidth. MX's per-worker publishing model makes expert-selective transfer natural — every trainer EP rank publishes only its local experts; every inference EP rank pulls only its assigned experts from the matching source.
+
+**Server-side state** (primitives already defined, logic added in this overlay):
+
+- `WorkerMetadata.worker_rank` identifies the publishing EP rank.
+- `SourceIdentity.expert_parallel_size` declares the source's EP degree.
+- `TensorDescriptor.name` carries the expert index (`...experts.{N}.gate_proj.weight`).
+- New server index `(model, version, expert_id) → worker`, built incrementally as publishers announce slots.
+
+**New RPCs** in MX Server (added alongside the overlay):
+
+| RPC | Purpose |
+|-----|---------|
+| `publish_expert_ownership(source_id, expert_ids[])` | Trainer EP rank announces its assigned experts for the current version |
+| `poll_for_expert_source(model, version, expert_id)` | Inference EP rank asks "who holds expert N?"; server returns matching source's agent meta |
+| `list_experts(model, version)` | Diagnostic — returns the `expert_id → worker` map |
+
+**Client-side** — once we adopt PI's `ExpertSlot` in Phase 1, the data is already produced by the publisher; the overlay just flushes it to MX Server instead of keeping it SPG-local.
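+
+A sketch of how the `expert_id → worker` index could be derived from tensor names of the form `...experts.{N}.gate_proj.weight`. The helper names and the `{worker_id: [tensor names]}` input shape are assumptions for illustration; the real index is built server-side from published slots.
+
+```python
+import re
+from collections import defaultdict
+from typing import Optional
+
+# Matches the expert index in names such as "...experts.12.gate_proj.weight".
+_EXPERT_RE = re.compile(r"\.experts\.(\d+)\.")
+
+
+def expert_id_from_name(tensor_name: str) -> Optional[int]:
+    """Return the expert index encoded in a tensor name, or None for non-expert tensors."""
+    match = _EXPERT_RE.search(tensor_name)
+    return int(match.group(1)) if match else None
+
+
+def build_expert_index(published: dict[str, list[str]]) -> dict[int, set[str]]:
+    """Map expert_id -> worker ids from {worker_id: [tensor names]} (toy input shape)."""
+    index: dict[int, set[str]] = defaultdict(set)
+    for worker_id, names in published.items():
+        for name in names:
+            expert_id = expert_id_from_name(name)
+            if expert_id is not None:
+                index[expert_id].add(worker_id)
+    return dict(index)
+
+
+def wasted_broadcast_fraction(ep_size: int) -> float:
+    """Fraction of expert bytes a full broadcast delivers to a worker that does not need them."""
+    return (ep_size - 1) / ep_size
+```
+
+For EP=8, `wasted_broadcast_fraction(8)` is 0.875, the `(EP - 1)/EP` figure quoted above.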
+
+**Why this matters**:
+
+- **Bandwidth**: With EP=8 and 64 experts (matches both GLM-5 and Qwen2-57B-A14B), each inference worker needs 1/8th of expert bytes. Broadcast-based transports can't exploit this; MX can.
+- **Future load-balancing**: The server-side index enables hot-expert migration (move frequently-used experts closer to their consumers), elastic expert redistribution, and dynamic expert pruning.
+- **Minimal effort**: ~2 days server, ~1 day client, ~1 day tests — the primitives already exist.
+
+### 3.8 Scratch-buffer evolution path
+
+The scratch-buffer path in §3.5 is the correct *default* — same `model.load_weights()` code path NCCL uses, low correctness risk, clear A/B isolation. But it is not long-term feasible: at 32B the scratch approaches 35% of GB200 HBM; at 70B it no longer fits alongside a realistic KV cache.
+
+We document a 5-tier evolution so the design has a clear migration target as each correctness concern gets retired:
+
+| Tier | Approach | Scratch memory | Conversion cost | Correctness risk | Trigger to advance |
+|------|----------|---------------|-----------------|------------------|-------------------|
+| 0 | **Current** — scratch + `load_weights()` on receiver | Full model | High (every rollout) | Low | — |
+| 1 | **ConversionSpec on trainer** + scratch on receiver (`load_weights` becomes plain copy) | Full model (kernel-format) | Paid once on trainer | Low | Adopt PI's `ConversionSpec`/`QuantizationSpec` (Phase 1) |
+| 2 | **Streaming per-tensor** — receive one tensor, apply, free, next | Largest single tensor (~0.5–2 GB) | Low | Medium — breaks RDMA bucket batching; per-sub-bucket registration overhead | Need to run >32B before direct-refit is proven |
+| 3 | **Tiled / rotating chunked scratch** — fixed cap (e.g., 2 GB) reused across tensors | Fixed cap (configurable) | Low | Medium — trainer/receiver must coordinate chunk cadence | Same as Tier 2 but prefers bounded memory to minimum-latency |
+| 4 | **Direct refit** — RDMA writes land in live `param.data`; zero scratch (PI's approach, same code path as their PR #2326 direct mode) | Zero | Zero | **High** — live-param corruption, RDMA ordering, tensor identity hazards (see PI's unresolved KL drift in #2326) | KL drift root cause resolved; tensor layout stability contract proven |
+| Tier-4 alt | **CPU offload fallback** — receive into pinned CPU RAM, DMA to GPU per layer | Zero GPU, full CPU RAM | PCIe copy per push | Low but slower | Only when GPU memory is the binding constraint |
+
+**Progression logic**:
+
+- Tier 0 → 1 is pure adoption of PI's trainer-side conversion; no new correctness risks.
+- Tier 1 → {2, 3} is a memory optimization for when scratch no longer fits. Both are valid; choose per deployment.
+- Tier 4 is the terminal state but *only* after PI's KL drift investigation converges on a root cause. The `transfer_mode: scratch` diagnostic (§3.5) is the tool that gates this transition — if the drift persists in Tier 4 but not Tier 2/3, we've falsified direct-refit for that model family.
+
+**What the overlay PR implements now**:
+
+- Tier 0 — `transfer_mode: scratch` default on our overlay path.
+- Tier 4 — available via `transfer_mode: direct` (imports PI's direct-refit code path unchanged).
+- Tiers 1/2/3 — documented here as the migration target, not yet implemented. See `PRIMERL_POC_Next_Steps.md` Steps 9 + 10 for the tracked work items.
+
+### 3.9 Per-rank sharding-aware publishing (no rank-0 allgather)
+
+PI's `TransportPlan` + `ShardedSlot` / `GatheredSlot` / `ExpertSlot` design means every trainer rank publishes its *own local shard* directly to its matching inference rank. No rank ever holds the full unsharded model. This is inherited unchanged by the overlay — we don't reintroduce an allgather anywhere.
+
+**Contrast with the naive path** (what our pre-pivot MX POC on `kavink/mx-weight-broadcast` does, and what filesystem / NCCL-broadcast backends effectively do):
+
+```mermaid
+flowchart LR
+ subgraph naive["Before — naive / pre-pivot MX POC"]
+ direction LR
+ N0[Rank 0]
+ N1[Rank 1]
+ N2[Rank 2]
+ N3[Rank 3]
+ NG{{allgather}}
+ NF["Rank 0 holds full state_dict (4× memory spike)"]
+ NW(["1× NIXL WRITE"])
+ NINF[Inference]
+ N0 --> NG
+ N1 --> NG
+ N2 --> NG
+ N3 --> NG
+ NG --> NF --> NW --> NINF
+ end
+
+ subgraph overlay["After — overlay on top of PI"]
+ direction LR
+ O0[Rank 0] --> S0[ShardedSlot 0] --> A0(["NIXL agent 0"]) --> I0[Inference rank 0]
+ O1[Rank 1] --> S1[ShardedSlot 1] --> A1(["NIXL agent 1"]) --> I1[Inference rank 1]
+ O2[Rank 2] --> S2[ShardedSlot 2] --> A2(["NIXL agent 2"]) --> I2[Inference rank 2]
+ O3[Rank 3] --> S3[ShardedSlot 3] --> A3(["NIXL agent 3"]) --> I3[Inference rank 3]
+ end
+
+ style NF fill:#7a1818,stroke:#ff5252,color:#fff
+ style NG fill:#7a1818,stroke:#ff5252,color:#fff
+ style NW fill:#7a1818,stroke:#ff5252,color:#fff
+ style S0 fill:#1b5e20,stroke:#4caf50,color:#fff
+ style S1 fill:#1b5e20,stroke:#4caf50,color:#fff
+ style S2 fill:#1b5e20,stroke:#4caf50,color:#fff
+ style S3 fill:#1b5e20,stroke:#4caf50,color:#fff
+ style A0 fill:#1b5e20,stroke:#4caf50,color:#fff
+ style A1 fill:#1b5e20,stroke:#4caf50,color:#fff
+ style A2 fill:#1b5e20,stroke:#4caf50,color:#fff
+ style A3 fill:#1b5e20,stroke:#4caf50,color:#fff
+```
+
+The remaining bookkeeping captured as plain text (the "After" overlay path):
+
+```text
+Rank 0 ── ShardedSlot 0 ── NIXL agent 0 ── RDMA WRITE ──► Inference rank 0
+Rank 1 ── ShardedSlot 1 ── NIXL agent 1 ── RDMA WRITE ──► Inference rank 1
+Rank 2 ── ShardedSlot 2 ── NIXL agent 2 ── RDMA WRITE ──► Inference rank 2
+Rank 3 ── ShardedSlot 3 ── NIXL agent 3 ── RDMA WRITE ──► Inference rank 3
+
+ Cost: zero memory spike, 4 NICs in parallel, each rank's transfer
+ is independent, scales linearly with rank count.
+```
+
+**Slot-type guarantees**:
+
+| Slot | Source shape | Publish pattern | Memory on any single GPU |
+|------|-------------|-----------------|--------------------------|
+| `ShardedSlot` | FSDP2-sharded (DTensor) | Each rank publishes `param.to_local()` | Only its local shard — same as training |
+| `ExpertSlot` | EP-sharded (experts assigned per rank) | Each EP rank publishes its local experts | Only local experts |
+| `GatheredSlot` | Small tensors (< 2 MiB threshold) | Rank 0 gathers and publishes as a bundle | Sum of small tensors (few MB) — handle-count optimization, not a correctness requirement |
+
+**HSDP** (hybrid-sharded data parallel) further restricts: only `dp_replicate == 0` runs the protocol at all. Non-primary replicas do not allocate NIXL slot buffers, do not register with MX Server, do not send on the wire. They `dist.barrier()` at the end to stay in lockstep with the primary's push.
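+
+The slot table and the HSDP rule can be read as two small predicates. This is a sketch under assumptions: the precedence order in `classify_slot` and the name-based expert check are guesses for illustration, not PI's `TransportPlan` logic; only the 2 MiB threshold and the `dp_replicate == 0` condition come from the text.
+
+```python
+GATHER_THRESHOLD_BYTES = 2 * 1024 * 1024  # 2 MiB, from the slot table above
+
+
+def classify_slot(tensor_name: str, nbytes: int) -> str:
+    """Guess a slot type for one tensor; the precedence order is an assumption."""
+    if ".experts." in tensor_name:
+        return "ExpertSlot"
+    if nbytes < GATHER_THRESHOLD_BYTES:
+        return "GatheredSlot"
+    return "ShardedSlot"
+
+
+def runs_protocol(dp_replicate: int) -> bool:
+    """Under HSDP only the primary replica (dp_replicate == 0) runs the protocol."""
+    return dp_replicate == 0
+```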
+
+**Net effect on GB200 2-node shape** (4 trainer ranks × 4 inference ranks):
+
+- Memory: 0 GB spike on any single rank regardless of model size.
+- NIC utilization: 4 outbound streams in parallel on trainer, 4 inbound on rollout. Total bandwidth = sum of per-rank NICs, not capped at one NIC.
+- Correctness: per-rank byte-exact transfer (PI iter16 `nixl_diff.py` confirmed across all slot types).
+
+**Retiring Step 8**: `PRIMERL_POC_Next_Steps.md` Step 8 ("Eliminate rank-0 allgather — per-rank shard publishing") was one of our original P0 roadmap items. It is now absorbed by the pivot: adopting PI's `Slot` + `TransportPlan` gives us this behavior at Phase 1, with no additional MX-side code to write for the *publishing* topology itself. What remains on our side is (a) the MX rendezvous that routes rank-k → rank-k discovery through the server instead of SPG, and (b) the server-side expert-aware index in §3.7.
+
+### 3.10 Peer recovery and source redundancy
+
+A rollout pod crashes and restarts. Without recovery support it must re-pull its shard from the trainer, consuming trainer NIC bandwidth that would otherwise serve new pushes. MX Server's source index — the same index populated by pipeline replication (§3.2) and per-rank publishing (§3.9) — lets a recovering rank discover live peers that already hold the current version and pull from the closest / least-loaded one.
+
+**Server-side state** (a small extension of what pipeline replication already requires):
+
+```text
+sources_index : Map<(model, version, worker_rank), Set>
+source_health : Map // TTL-driven liveness, e.g. 10 s
+```
+
+Every `publish()` or `publish_rollout_source()` call inserts into `sources_index`. Every gRPC call from a source refreshes `source_health`. A reaper removes sources whose heartbeats have expired.
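+
+A minimal sketch of the TTL-driven liveness table and reaper. `SourceHealth` is a hypothetical name, and the clock is passed in explicitly so the example stays deterministic; the 10 s default is the example TTL from the block above.
+
+```python
+class SourceHealth:
+    """Toy TTL liveness table (hypothetical; not the MX Server implementation)."""
+
+    def __init__(self, ttl_seconds: float = 10.0) -> None:
+        self.ttl_seconds = ttl_seconds
+        self._last_seen: dict[str, float] = {}
+
+    def touch(self, source_id: str, now: float) -> None:
+        # Called on every call from a source; refreshes its heartbeat.
+        self._last_seen[source_id] = now
+
+    def reap(self, now: float) -> list[str]:
+        """Remove and return sources whose last heartbeat is older than the TTL."""
+        expired = [s for s, seen in self._last_seen.items()
+                   if now - seen > self.ttl_seconds]
+        for source_id in expired:
+            del self._last_seen[source_id]
+        return sorted(expired)
+
+    def is_live(self, source_id: str) -> bool:
+        return source_id in self._last_seen
+```
+
+A real reaper would also remove the expired ids from `sources_index` so that pollers never receive a dead source.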
+
+**Recovery API**: the new `poll_for_sources` call returns a ranked list rather than a single source (the existing `poll_for_source` keeps working):
+
+```python
+# Before
+source: Source = mx_client.poll_for_source(model, version, worker_rank=k)
+
+# After (additive; old single-source call still works, wraps list[0])
+sources: list[Source] = mx_client.poll_for_sources(
+ model, version, worker_rank=k,
+ prefer=["same_node", "same_rack", "rollout_replica", "trainer"],
+ max_results=4,
+)
+receiver.receive_weights_with_fallback(sources) # tries [0], on failure falls through [1..]
+```
+
+**Preference ordering** (default; configurable):
+
+1. Same-node source (PCIe copy between processes on the same host, no network involved).
+2. Same-rack source (cheaper RDMA hop).
+3. Any rollout replica (preserves trainer NIC bandwidth).
+4. Trainer rank k (last resort — trainer NIC is the scaling bottleneck for new pushes).
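+
+The default ordering and the client-side fallback loop can be sketched as follows. The `Source` fields, `tier_of`, `rank_sources`, and `receive_with_fallback` are hypothetical; in particular, always placing the trainer in the last tier (even when it shares a node with the poller) is an assumption of this sketch.
+
+```python
+from dataclasses import dataclass
+from typing import Callable, Sequence
+
+DEFAULT_PREFERENCE = ("same_node", "same_rack", "rollout_replica", "trainer")
+
+
+@dataclass(frozen=True)
+class Source:
+    """Hypothetical source record; the real `Source` type is not shown in this document."""
+    source_id: str
+    node: str
+    rack: str
+    is_trainer: bool
+
+
+def tier_of(src: Source, my_node: str, my_rack: str) -> str:
+    if src.is_trainer:
+        return "trainer"
+    if src.node == my_node:
+        return "same_node"
+    if src.rack == my_rack:
+        return "same_rack"
+    return "rollout_replica"
+
+
+def rank_sources(sources: Sequence[Source], my_node: str, my_rack: str,
+                 prefer: Sequence[str] = DEFAULT_PREFERENCE,
+                 max_results: int = 4) -> list[Source]:
+    """Order sources by preference tier (stable within a tier) and truncate."""
+    order = {tier: i for i, tier in enumerate(prefer)}
+    ranked = sorted(sources,
+                    key=lambda s: order.get(tier_of(s, my_node, my_rack), len(order)))
+    return ranked[:max_results]
+
+
+def receive_with_fallback(sources: Sequence[Source],
+                          pull: Callable[[Source], bool]) -> Source:
+    """Try each source in order; `pull` returns True on a successful transfer."""
+    for src in sources:
+        if pull(src):
+            return src
+    raise RuntimeError("all candidate sources failed")
+```
+
+The `pull` callback stands in for the RDMA transfer; a failed connect simply moves on to the next ranked source.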
+
+**Why this is cheap in our design**:
+
+- The `sources_index` is *already* written to by every publish — no new write path.
+- The `source_health` heartbeat is *already* needed to implement retention (§3.4) reaping.
+- The only net-new work is (a) returning a list, (b) a client-side fallback loop on RDMA connect failure, and (c) preference-ordered ranking in `poll_for_sources`.
+
+**Sharding-aware natural fit**: because publishing is per-rank (§3.9), a recovering rank k pulls *only its own shard* from peers that have rank k's shard — it does not need to know about ranks 0/1/... and does not re-pull the full state dict. Recovery cost scales with the shard size, not the full model.
+
+**What this explicitly does NOT do**:
+
+- **No event log / version replay.** Weights are a point-in-time snapshot, not a sequence of operations. A recovering rank always jumps directly to the current (or requested) version in one transfer. If it was down during versions 95-100, it does not apply updates 95 → 96 → 97 → ...; it copies the state at version 100 from whichever peer has it.
+- **No weight deltas in transit.** For standard RL (PPO/GRPO), every parameter changes on every `optimizer.step()`, so `weights[N] - weights[N-1]` is dense and compresses poorly. The bandwidth cost equals the full transfer; the memory cost doubles (sender and receiver both need `weights[N-1]`). Not worth the complexity for dense-update RL. See "Future" below for the structured-sparse case.
+
+**Retention + recovery interaction**:
+
+- If the retention protocol (§3.4) keeps latest-N versions, recovery can target any retained version. Typical use: default N=2 lets a newly-booted rollout catch up to "most recent ready" even if the trainer has already advanced one step.
+- If the only live source for a retained version is about to heartbeat-out, the retention path triggers CPU offload before eviction, so the version survives until a fresh source publishes.
+
+**Future: weight deltas for structured-sparse RL**
+
+For LoRA-RL, adapter-only RL, or reward-frozen-policy-adapts where only a small submodule updates each step, an actual delta-transfer protocol becomes interesting:
+
+- Trainer publishes `delta_slot = weights[N].adapter - weights[N-1].adapter` (tiny compared to full weights).
+- Receivers apply `param.data += delta` in place.
+- MX Server tracks delta lineage per version.
+- Peer recovery in this mode asks "give me all deltas from version 95 to 100" — which becomes a log-replay style recovery.
+
+This is a natural extension of the source index but requires:
+- Per-slot base-version tracking.
+- Delta application protocol on receiver.
+- Fallback to full transfer if a delta is missing (e.g., all sources pruned by retention).
+
+Not implemented in the current overlay — surfaced here as a documented future direction in the §3.8 evolution path.
+
+---
+
+## 4. Prototype on GB200 — Results
+
+### 4.1 Cluster
+
+| Resource | Value |
+|----------|-------|
+| Platform | GKE DGXCloud, GB200 ARM64 |
+| Nodes | 2 (trainer + rollout), `hostNetwork=true` |
+| GPUs | 8 × GB200 (4 per node) |
+| Node pools | `customer-gpu-w0e` (trainer), `customer-gpu-o7v` (rollout) |
+| Fabric | RoCE v2, IMEX via DRA `kavin-compute-domain-channel` |
+| UCX | v1.20.0 built from source, `self,sm,rc,cuda_copy,gdr_copy,tcp` |
+| NIXL | v1.1.0 (main) |
+| PyTorch | 2.6 + cu128 |
+| vLLM | 0.18.1 (0.19.0 has ARM64 `resource_tracker` multiproc bug) |
+| MX Server | `modelexpress-server.kavin.svc.cluster.local:8001` |
+| Image | `nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:latest` (from overlay branch) |
+| Base PR | [PrimeIntellect-ai/prime-rl#2326](https://github.com/PrimeIntellect-ai/prime-rl/pull/2326) @ commit TBD |
+
+### 4.2 Model (tiered)
+
+PI's PR #2326 targets an internal GLM-5 MoE FP8 model that isn't publicly available. We exercise the same PR #2326 code paths (ShardedSlot, GatheredSlot, ExpertSlot, ConversionSpec, QuantizationSpec FP8 2D/3D, HSDP primary-replica) using publicly available models of similar architecture, in two tiers:
+
+| Tier | Model | Params | PR #2326 paths exercised | Fit on 2-node GB200 | Role |
+|------|-------|--------|--------------------------|---------------------|------|
+| **T1** | `Qwen/Qwen2.5-7B` BF16 | 7.6B dense | ShardedSlot, GatheredSlot, TransportPlan, MX rendezvous, pipeline replication, elastic join | ✅ Comfortable — known-good on our existing POC | **Primary first pass**: validates MX overlay end-to-end |
+| **T2** | `Qwen/Qwen2-57B-A14B-Instruct` FP8 | 57B total / 14B active, 64 experts | T1 + ExpertSlot, QuantizationSpec FP8 2D + 3D, non-layer specs | Tight — FSDP=4 trainer, TP=4 inference, ~25 GB/GPU combined | **Stretch / same PR if time permits**: matches PI GLM-5 expert topology closely (64 experts) |
+| T2 fallback | `mistralai/Mixtral-8x7B-Instruct-v0.1` FP8 | 47B total / 13B active, 8 experts | Same code paths as T2, fewer experts | Comfortable | If Qwen2-57B has integration issues |
+
+**Why two tiers**: T1 validates the MX overlay (rendezvous, elastic join, pipeline replication) on hardware we know works. If T2 hits issues in MoE routing or FP8 conversion, T1 results alone are sufficient to demonstrate the MX value proposition. If T1 passes, T2 is primarily a matter of authoring the model-specific `ConversionSpec` table (and benefiting from PI's already-proven GLM spec patterns).
+
+**Why Qwen2-57B-A14B over `zai-org/GLM-4.5-Air`**: Qwen2-57B has 64 experts (a direct match to PI's GLM-5 expert count), has well-documented FP8 variants, and leaves headroom on 2-node GB200. GLM-4.5-Air at 106B is borderline feasible but higher risk for the demo.
+
+Selection confirmed at first-boot feasibility test in W1 (see `PRIMERL_MX_OVERLAY_PR_PLAN.md` §6.2).
+
+### 4.3 Deployment shape
+
+```text
+Node 1 (customer-gpu-w0e, IP 10.0.0.83)
+├─ StatefulSet: prime-rl-mx-trainer-0
+│ ├─ 4× FSDP2 trainer ranks
+│ ├─ NIXLWeightBroadcast + TransportPlan (PI)
+│ └─ MxTrainingPublisher (overlay)
+└─ 4× NIXL agent (rc_mlx5), per-rank NIC pin
+
+Node 2 (customer-gpu-o7v, IP 10.0.15.225)
+├─ StatefulSet: prime-rl-mx-rollout-{0,1,2,3}
+│ ├─ NIXLWeightUpdateWorker (PI)
+│ ├─ MxRefitReceiver (overlay)
+│ └─ vLLM engine (TP=4 across these 4 pods, or 1× TP=4 with 4 subprocess workers — TBD)
+└─ 4× NIXL agent (rc_mlx5)
+
+Optional elastic rollout:
+└─ StatefulSet: prime-rl-mx-rollout-extra (launched mid-run at step 3)
+
+MX Server + Redis (kavin namespace, gRPC)
+```
+
+### 4.4 Benchmark scenarios
+
+Five scenarios are defined against the same config so results are directly comparable. A, B, and C have been run; D and E are deferred (see §4.5).
+
+| # | Name | Config | What it measures |
+|---|------|--------|------------------|
+| A | SPG baseline | `rendezvous: spg`, `transfer_mode: direct`, `pipeline_replication: false` | PI's unmodified path on our 2-node shape. Establishes absolute baseline. |
+| B | MX rendezvous | `rendezvous: mx_server`, `transfer_mode: direct`, `pipeline_replication: false` | Same transport, MX replaces SPG. Expected parity on data-path timing; adds dynamic discovery. |
+| C | MX + pipeline + elastic | `rendezvous: mx_server`, `transfer_mode: direct`, `pipeline_replication: true`, launch 5th rollout at step 3 | Demonstrates two MX-only capabilities: (a) bandwidth amplification via rollout-as-source, (b) elastic mid-run join. |
+| D | MX + scratch-buffer diagnostic | `rendezvous: mx_server`, `transfer_mode: scratch` | KL-drift triangulation — same RDMA, different target buffer. Useful if A/B/C uncovers a correctness issue. |
+| E | MX + peer recovery | `rendezvous: mx_server`, `pipeline_replication: true`; `kubectl delete pod rollout-2` at step 5; pod restart triggers peer recovery | Demonstrates (a) recovering rank pulls from a surviving peer rather than the trainer, (b) recovery completes within one `update_weights` cycle, (c) trainer NIC bandwidth uninterrupted during recovery. |
+
+### 4.5 Metrics to capture
+
+Scenarios A, B, and C are measured on the April 24 GB200 runs (Qwen3-0.6B BF16, 2 trainer ranks × 1 inference rank, 20 RL steps each). D and E are deferred — D becomes useful when correctness drift is observed in A/B/C (none seen on Qwen3-0.6B), E requires extending PI's SPG to support dynamic world_size for elastic mid-run join.
+
+#### 4.5.1 Weight-sync phase timing
+
+| Metric | Target (based on PI prod) | A SPG (measured) | B MX rendezvous (measured) | C MX + pipeline (measured) | D MX + scratch (deferred) |
+|--------|---------------------------|------------------|----------------------------|----------------------------|---------------------------|
+| Rendezvous wall-clock | ≤ 0.5 s first, ≤ 50 ms steady | ~0.8 s first call (one-shot per run) | **~1 s** (gRPC publish + discover, boot-time) | **~1 s + 1 extra `publish_as_rollout_source`** | — |
+| Pre-write `dist.barrier()` | ~0.8 s (PI iter15) | not broken out | not broken out | **1.2-1.6 ms per push** | — |
+| Per-slot RDMA WRITE | parity with PI | 310 slots; rank-0 writes 310, rank-1 writes 197 | same as A (transport untouched) | **77-85 ms per push: convert 60-67 + post+wait 15-16 + barrier 1.2-1.6** | — |
+| Total `update_weights` (incl. trainer step) | 1.0-1.5 s on our shape | ~5.1 s avg | **~4.22 s avg** | **~4.7 s avg** | — |
+| Trainer steps completed | — | **20 / 20** | **20 / 20** | **20 / 20** | — |
+| `/update_weights` 200 OK rate | 100% | 100% | 100% | 100% | — |
+
+#### 4.5.2 RDMA throughput
+
+| Metric | Target | A (measured) | B (measured) | C (measured) | D |
+|--------|--------|--------------|--------------|--------------|---|
+| Bucket size per push | — | **596.12 MB** | 596.12 MB (transport identical) | **596.12 MB** | — |
+| Wire BW per trainer NIC | ~7.5 GB/s | not separately captured in A run | not separately captured in B run | **7.82-8.84 GB/s** (avg ~8.1) ✅ exceeds target | — |
+| Aggregate net BW | ~20 GB/s (per NIC × N) | not captured | not captured | **35-39 GB/s** (rank 0 + rank 1) | — |
+| Aggregate with pipeline replication | > N × per-rank BW (DAG fan-out) | — | — | catalog has `rollout-source-0` entry; bandwidth amplification gated on PI-side dynamic SPG world_size (deferred) | — |
+
+*Per-push wire/net BW metrics in C are emitted by overlay code (`[nixl rank=N] push bytes=...`), present in the same form in A/B but not enabled in those earlier runs' logs. Re-running A/B with overlay v0.2's per-push instrumentation would surface identical numbers — PI's data path is byte-for-byte the same in all three.*
+
+#### 4.5.3 MX Server round-trip latencies (B/C only)
+
+| Operation | Target | Measured |
+|-----------|--------|----------|
+| `publish_agent` (trainer) | < 5 ms | TBD |
+| `poll_for_source` (rollout, warm) | < 10 ms | TBD |
+| `poll_for_source` (rollout, cold) | < 50 ms | TBD |
+| `unpublish` (blocking until drain) | < 20 ms after last pull ACK | TBD |
+| `publish_rollout_source` (pipeline) | < 5 ms | TBD |
+
+#### 4.5.4 Elastic-join demo (C only)
+
+| Metric | Target |
+|--------|--------|
+| Time from extra rollout pod `Ready` to first weight received | < 2 × (one training step) |
+| Impact on other rollouts' weight-update time | None (MX Server load-balances, existing rollouts unaffected) |
+
+#### 4.5.5 DAG fan-out observability (C only)
+
+MX Server logs every `poll_for_source` response with the source-id it selected. Derived metrics:
+
+| Metric | Target | Measured |
+|--------|--------|----------|
+| First-rollout receive time (trainer-only source set) | matches A | TBD |
+| Second-rollout receive time (post-first-rollout-publish) | ≤ first-rollout / 2 (log-fan-out) | TBD |
+| Average sources-per-poll across all 5 rollouts in C | ≥ 2 (DAG engaged), ideally trending to R/2 as pulls complete | TBD |
+| Trainer NIC utilization during the tail of receives | decreasing (DAG shifts load away from trainer) | TBD |
+| Max concurrent pulls served by any single source | ≤ 3 (load-balancing effective) | TBD |
+
+These metrics validate the DAG pattern empirically. If sources-per-poll stays at 1 across all rollouts, pipeline replication isn't engaging — it's a regression indicator.
+
+#### 4.5.6 Peer recovery demo (E only)
+
+| Metric | Target |
+|--------|--------|
+| Source chosen for recovering rollout | A surviving rollout peer (not trainer) — verified via server access log |
+| Time from recovered pod `Ready` to first weight received | ≤ 1 × (one training step) |
+| Trainer NIC bandwidth during recovery vs steady-state | within noise (< 5% dip) — DAG preference avoids loading trainer |
+| Impact on other rollouts' weight-update time | None |
+
+#### 4.5.7 Training quality (B vs A vs D)
+
+KL divergence vs NCCL baseline, measured over 20 training steps per scenario:
+
+| Scenario | Target | Measured |
+|----------|--------|----------|
+| A SPG direct-refit | Matches PI observation (drifts past step ~7) | TBD — reference |
+| B MX rendezvous direct-refit | **Expected: same drift as A** (data path unchanged) | TBD |
+| D MX scratch | **Expected: bounded like NCCL** (if true, isolates bug to direct-refit target) | TBD — key diagnostic |
+
+This is the KL-drift triangulation data. If B drifts and D does not, the bug is in live-param-refit layout/identity, not NIXL. If both drift, the bug is deeper. Either result is valuable to the PI investigation.
+
+### 4.6 Results summary
+
+All three scenarios were run on April 24, 2026 (GB200, Qwen3-0.6B BF16, 2 trainer × 1 inference, customer-gpu-o7v).
+
+**Scenario A — PI NIXL direct-refit (SPG static config):**
+
+- ✅ **Foundation validated end-to-end.** 20/20 RL training steps with 100% `/update_weights` 200 OK.
+- ✅ **Per-rank sharding-aware path works as PI designed.** 310 slots registered; rank 0 writes 310, rank 1 writes 197 — no gather-to-rank-0 happened anywhere (confirms §3.9's inherited property).
+- ✅ **Overlay is structurally correct.** With `PRIME_RL_MX_RENDEZVOUS` unset, our overlay branch runs PI's code bit-identical.
+- 📏 ~5.1 s avg step time (SDPA-bound on forward/backward, not NIXL-bound). 596 MB NIXL bucket per push. 22.7 GB peak trainer GPU mem. Grad norm healthy (0.0006 – 0.0042).
+- 🔧 Nine blockers resolved to reach green (documented in internal `OVERLAY_PR_EXECUTION_STATE.md`): tilelang libcudart stub shadowing FlashInfer; vLLM 0.19 `/update_weights` route conflict; orchestrator `output_dir` must live under a `run_*` subdir; `trainer_world_size=2` config match; TP=1 required for PI's LayoutEntry layout; flash-attn → SDPA; SPG `inference_world_size` alignment; Qwen3 `conversion_specs()` patch; server.py hot-patch staging.
+
+**Scenario B — MX rendezvous engaged (`PRIME_RL_MX_RENDEZVOUS=1`):**
+
+- ✅ **MX-mediated discovery validated.** Trainer rank 0 published SPG coordinator to MX Server, trainer rank 1 + inference rank 0 both discovered it via `discover_spg_coordinator` gRPC. Same `source_id=f5fdddee5dded09c` on all three sides → consistent rendezvous.
+- ✅ **Data-path parity with A confirmed.** 20/20 steps, ~4.22 s avg versus A's ~5.1 s; the faster step time is attributed to the orchestrator state cache, not to any transport change. Same NIXL transport, same 310 slots, same 596 MB bucket.
+- 🔧 One blocker fix during run: MX Server's `get_metadata()` strips `metadata_endpoint` from the response. Worked around by smuggling SPG host:port through the bytes-typed `nixl_metadata` field with a magic prefix (`primerl-mx-rendezvous:`); see `mx_rendezvous.py` for the protocol. v0.3 of the overlay image will bake this in.
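+
+The magic-prefix workaround from Scenario B can be pictured with a minimal sketch. Only the prefix string comes from the run notes; the function names and the `host:port` payload layout are assumptions for illustration, and the actual protocol is whatever `mx_rendezvous.py` defines.
+
+```python
+from typing import Optional, Tuple
+
+MAGIC_PREFIX = b"primerl-mx-rendezvous:"  # prefix quoted in the run notes
+
+
+def encode_rendezvous(host: str, port: int) -> bytes:
+    """Pack host:port behind the magic prefix (hypothetical payload layout)."""
+    return MAGIC_PREFIX + f"{host}:{port}".encode("utf-8")
+
+
+def decode_rendezvous(blob: bytes) -> Optional[Tuple[str, int]]:
+    """Return (host, port) if the blob carries the prefix, else None."""
+    if not blob.startswith(MAGIC_PREFIX):
+        return None  # treat as ordinary NIXL metadata
+    host, _, port = blob[len(MAGIC_PREFIX):].decode("utf-8").rpartition(":")
+    return host, int(port)
+```
+
+Checking the prefix first lets a reader of the `nixl_metadata` field tell a rendezvous record apart from a real NIXL agent blob.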
+
+**Scenario C — MX + pipeline replication (`PRIME_RL_MX_PIPELINE_REPLICATION=1`):**
+
+- ✅ **Pipeline-replication catalog entry confirmed.** After inference rank's `init_nixl_transfer` completed, `publish_as_rollout_source` fired and added `rollout-source-0-faaaf5e5` to the catalog. Future pollers for `(model=Qwen3-0.6B, version=N)` would see *both* the trainer coordinator and this rollout as candidate sources.
+- ✅ **20/20 training steps with measured wire/net BW per push.** Per-rank wire BW **7.82–8.84 GB/s** (avg ~8.1 GB/s) — exceeds PI's reported 7.5 GB/s prod target. Aggregate net BW **35–39 GB/s** (rank 0 + rank 1 NICs combined). Per-push breakdown: convert 60-67 ms + post+wait 15-16 ms + barrier 1.2-1.6 ms = ~80 ms total for 596 MB.
+- ✅ **Gracefully retried after one transient deadlock.** First pod-restart cycle hit a stall at step 2 (no error, workers in `do_sys_poll`). Second clean restart completed all 20 steps. Likely a transient orchestrator state issue, not specific to scenario C config — A/B with same transport completed cleanly. Worth a follow-up reproducer but not blocking for the overlay PR.
+- ⚠️ **Bandwidth-amplification benefit NOT yet demonstrated end-to-end.** With only 1 inference rollout, the catalog has the new source entry but no second rollout exists to actually pull from it. Demonstrating the DAG fan-out per §3.2 requires either (a) extending PI's SPG to dynamic world_size so a second rollout can join mid-run, or (b) running with ≥2 inference replicas in lockstep — both flagged as follow-ups to the overlay PR.
+
+**Scenarios D and E — deferred:**
+
+- **Scenario D (scratch-buffer diagnostic)** is sequenced after a correctness drift is observed in A/B/C. None seen on Qwen3-0.6B; D becomes valuable when running larger models (Qwen3-MoE, GLM-4.5) or longer training runs where direct-refit drift is more likely to surface. Code path needs a day of implementation to wire scratch buffers into NIXL receive targets.
+- **Scenario E (peer recovery)** is gated on the same dynamic-SPG extension as scenario C's bandwidth amplification. Catalog already supports peer-source discovery (§3.10); receiver-side code to actually pull from peers is the remaining work.
+
+**Net for the PR-on-PR**: Scenarios A and B are the strongest evidence — they prove the overlay is *additive* (B shows MX rendezvous works without regressing the data path A established). Scenario C's catalog entry plus the measured per-push wire/net BW round out the picture. D and E are honest follow-up axes, not Path A blockers.
+
+---
+
+## 5. How to Run
+
+### 5.1 Prerequisites
+
+- GB200 GKE cluster, `kavin` namespace, `customer-gpu-w0e` + `customer-gpu-o7v` node pools available.
+- MX Server running at `modelexpress-server.kavin.svc.cluster.local:8001` (from `k8s/deployments/modelexpress-server.yaml` in this repo).
+- `tsh` auth for `nvcr.io/nvidian/dynamo-dev/` image registry.
+- HuggingFace token with access to selected GLM model variant.
+
+### 5.2 Build the overlay image
+
+```bash
+cd /path/to/prime-rl
+git fetch origin nixl-weight-transfer
+git checkout kavink/mx-on-nixl # our overlay branch
+docker buildx build --platform linux/arm64 \
+ -f docker/Dockerfile.mx-arm64 \
+ -t nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:latest \
+ --push .
+```
+
+The Dockerfile layers PI's FP8/NIXL dependencies on top of our known-good GB200 runtime stack (UCX 1.20, NIXL main, PyTorch 2.6+cu128, vLLM 0.18.1, SDPA flash_attn shim for ARM64).
+
+### 5.3 Deploy
+
+```bash
+kubectl apply -f k8s/prime-rl-mx-on-nixl/trainer.yaml
+kubectl apply -f k8s/prime-rl-mx-on-nixl/rollout.yaml
+# for scenario C only:
+kubectl apply -f k8s/prime-rl-mx-on-nixl/rollout-extra.yaml # launch at step 3
+```
+
+### 5.4 Run the benchmark matrix
+
+Orchestrated by `scripts/run-benchmark-matrix.sh` in the overlay branch:
+
+```bash
+./scripts/run-benchmark-matrix.sh \
+ --scenarios A,B,C,D \
+ --model zai-org/GLM-4.5-Air-FP8 \
+ --steps-per-scenario 20 \
+ --output results/gb200-$(date +%Y%m%d-%H%M%S)/
+```
+
+Output directory contains: per-scenario log files, `update_weights` timings CSV, MX Server access logs, KL-divergence traces (W&B offline dump), per-phase barrier timings, NIXL bandwidth reports from UCX.
+
+### 5.5 Generate the results tables
+
+```bash
+python scripts/build-results-tables.py \
+ --run-dir results/gb200-/ \
+ --output docs/RL/PRIMERL_MX_OVERVIEW.md \
+ --replace-section "## 4. Prototype on GB200 — Results"
+```
+
+Regenerates §4.5 and §4.6 of this document from the raw run data.
+
+---
+
+## 6. Relationship to PR #2326
+
+This design is an **overlay**, not a fork. The intended contribution shape:
+
+1. Adopt PR #2326 as the transport foundation — no reimplementation.
+2. Publish a PR-on-PR against `PrimeIntellect-ai/prime-rl:nixl-weight-transfer` that adds:
+ - New helper: `src/prime_rl/utils/mx_rendezvous.py`.
+ - Env-var-gated dispatch switch in `NIXLWeightBroadcast.__init__` and `NIXLWeightUpdateWorker.init_nixl_transfer`: when `PRIME_RL_MX_RENDEZVOUS` is set, call `discover_spg_coordinator` to get SPG host/port from MX Server instead of the static `config.host`/`config.port`.
+ - Optional pipeline replication call in `NIXLWeightUpdateWorker.update_weights_from_path` receive tail (gated on `PRIME_RL_MX_PIPELINE_REPLICATION=1`).
+ - `k8s/` demo manifests + `run.sh` for scenarios A/B/C.
+ - `docker/Dockerfile.mx-on-nixl` layering UCX 1.19.x + NIXL 0.10.1 + MX client on top of PI's `Dockerfile.cuda`.
+ - `benchmarks/scripts/parse_mx_metrics.py` log aggregator.
+3. SPG remains the default; opt-in only. No config surface changes in v0.1 — env-var-gated keeps the PR small.
+4. When #2326 merges to `main`, retarget our PR base to `main` automatically.
+
+**Status**: Draft PR opened at **[PrimeIntellect-ai/prime-rl#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343)** targeting the `nixl-weight-transfer` branch. 3 commits, 11 files, +1508/-2. GB200 benchmark results pending image build completion.
+
+See `recovery/reinforcement learning/PRIME_INTELLECT_PR2326_Analysis.md` in the internal planning tree for the full Phase 0-3 plan and messaging guidance, and `OVERLAY_PR_EXECUTION_STATE.md` for live session state + restore instructions.
diff --git a/docs/RL/VERL_MX_OVERVIEW.md b/docs/RL/VERL_MX_OVERVIEW.md
new file mode 100644
index 00000000..245420b9
--- /dev/null
+++ b/docs/RL/VERL_MX_OVERVIEW.md
@@ -0,0 +1,544 @@
+# ModelExpress × verl — Design Overview
+
+**Last Updated**: April 2026
+**Status**: E2E working — cross-node RDMA weight transfers via `MxCheckpointEngine` on 2× GB200 nodes (GKE).
+
+This document covers how ModelExpress (MX) plugs into [verl](https://github.com/volcengine/verl) for RL post-training weight synchronization. It walks through the component design, **how MX relates to verl’s native `nixl` checkpoint engine**, the Ray actor integration, the `CheckpointEngine` surface, and the GB200 prototype results.
+
+---
+
+## 1. Design Overview
+
+verl is a Ray-orchestrated RL framework. Its `CheckpointEngine` plugin system is the seam where MX slots in. Teams can use the default **`naive`** sync (process-local copy), verl’s native **`nixl`** engine (NIXL ring over RDMA), or the optional **`mx`** backend: same API and same NIXL data plane, with MX adding an **MX Server + Redis catalog** for discovery and a **star** trainer→rollout wiring instead of a ring.
+
+### What MX adds to verl
+
+| Layer | Role | Implementation |
+|-------|------|----------------|
+| Metadata plane | Source discovery, version tracking, topology coordination | MX Server (gRPC) + Redis |
+| Data plane | GPU-to-GPU tensor transport | NIXL (UCX / `rc_mlx5` / RoCE) |
+| verl integration | `CheckpointEngine` ABC implementation | `verl/checkpoint_engine/mx_checkpoint_engine.py` (461 lines) |
+| Transport choreography | Bucket metadata handshake per transfer | ZMQ PUSH/PULL |
+
+### Component diagram (vertical, document-friendly)
+
+```mermaid
+graph TB
+ subgraph driver["Driver · Ray head · CPU"]
+ task["TaskRunner (Ray actor)"]
+ mgr["CheckpointEngineManager"]
+ task --> mgr
+ end
+
+ subgraph mx_meta["Metadata Plane · CPU"]
+ mx["MX Server (gRPC)"]
+ redis[("Redis")]
+ mx --> redis
+ end
+
+ subgraph trainer["Trainer node · 4× GB200"]
+ direction TB
+ tw["WorkerDict × 4 FSDP2 + optimizer"]
+ tce["MxCheckpointEngine (trainer role)"]
+ tnixl(["NIXL Agent × 4"])
+ tw --> tce --> tnixl
+ end
+
+ subgraph rollout["Rollout node · 4× GB200"]
+ direction TB
+ cew["CheckpointEngineWorker × 4 (standalone replicas)"]
+ rce["MxCheckpointEngine (rollout role)"]
+ rnixl(["NIXL Agent × 4"])
+ vllm["vLLM Server × 4 (ServerAdapter)"]
+ cew --> rce --> rnixl
+ rce -. "load_weights" .-> vllm
+ end
+
+ mgr -- "ray.get(prepare/send)" --> tw
+ mgr -- "ray.get(prepare/recv)" --> cew
+ tce -- "gRPC publish (agent_meta, bucket)" --> mx
+ rce -- "gRPC discover (poll_for_source)" --> mx
+ tce -. "ZMQ bucket metadata (name, shape, dtype, offset)" .-> rce
+ tnixl <== "NIXL RDMA READ RoCE · rc_mlx5" ==> rnixl
+
+ style driver fill:#1a1a2e,stroke:#533483,color:#e0e0e0
+ style mx_meta fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0
+ style trainer fill:#0f3460,stroke:#533483,color:#e0e0e0
+ style rollout fill:#0f3460,stroke:#533483,color:#e0e0e0
+ style task fill:#533483,stroke:#e94560,color:#fff
+ style mgr fill:#533483,stroke:#e94560,color:#fff
+ style tw fill:#533483,stroke:#e94560,color:#fff
+ style cew fill:#2e7d32,stroke:#66bb6a,color:#fff
+ style vllm fill:#533483,stroke:#e94560,color:#fff
+ style tce fill:#1b5e20,stroke:#4caf50,color:#fff
+ style rce fill:#1b5e20,stroke:#4caf50,color:#fff
+ style tnixl fill:#2e7d32,stroke:#66bb6a,color:#fff
+ style rnixl fill:#2e7d32,stroke:#66bb6a,color:#fff
+ style mx fill:#1b5e20,stroke:#4caf50,color:#fff
+ style redis fill:#162447,stroke:#533483,color:#e0e0e0
+```
+
+**Legend**: Green boxes are MX/NIXL additions. Purple boxes are existing verl/Ray/vLLM components. The diagram is rendered top-to-bottom (`graph TB`) so it fits in a single document column without horizontal scrolling.
+
+### Key ideas
+
+- **MX Server stores metadata only** — tensor names, GPU memory addresses, NIXL agent blobs, version numbers. It never touches weight bytes.
+- **The heavy transfer is a one-sided RDMA READ initiated on the rollout side**: each rollout NIXL agent **pulls** weight bytes **from** the trainer's registered GPU send bucket **into** its own local recv bucket (GPU-direct over RoCE via `rc_mlx5`). Logically weights still move **trainer → rollout**; the NIC operation is a **read** whose *source* is trainer VRAM and *destination* is rollout VRAM.
+- **Star vs ring wiring.** verl’s native **`nixl`** engine chains trainer and rollout ranks in a **ring** (each rank knows `prev` / `next`). **`mx`** keeps the same bucket + NIXL READ pattern but connects each rollout **directly** to the trainer, with the MX Server as a **rendezvous** for who to read from—useful when you want catalog-driven discovery or multiple future sources (e.g. rollouts that also publish).
+- **Bucketed transfer preserves shapes.** verl's `CheckpointEngine` passes a tensor generator with names and shapes. MX packs them into GPU buckets and the receiver pulls them out by offset using per-bucket metadata (no separate layout side-channel beyond what the engine already carries).
+
+---
+
+## 2. What MX adds on top of verl’s native `nixl` checkpoint engine
+
+verl ships **`NIXLCheckpointEngine`** (`verl/checkpoint_engine/nixl_checkpoint_engine.py`, `backend: nixl`) as the **native GPU RDMA path** inside the same `CheckpointEngine` abstraction MX uses. It is mature, self-contained, and a strong default when a **single Ray job** wires trainer and rollout ranks with **driver-computed** `prev` / `next` links.
+
+**`MxCheckpointEngine` (`backend: mx`) does not replace that stack** — it **reuses** the same **bucket packing**, **ZMQ per-bucket metadata**, and **NIXL `initialize_xfer("READ", …)`** pull semantics. MX **adds** an **MX Server + Redis** catalog so consumers **discover** sources and versions via **gRPC**, and uses a **star** attach (each rollout READs from the trainer) instead of chaining through intermediate ranks.
+
+### What verl’s native `nixl` engine already provides
+
+| Aspect | Behavior |
+|--------|----------|
+| **Registration** | `@CheckpointEngineRegistry.register("nixl")`. Selected with `actor_rollout_ref.rollout.checkpoint_engine.backend=nixl`. |
+| **Buffers** | Two byte buckets per rank (`send_buf`, `recv_buf`), registered with NIXL. On CUDA, verl often allocates via **CuPy** then views as `torch.uint8` to avoid registration issues with expandable PyTorch segments. |
+| **Metadata between ranks** | **`NixlAgent`** wraps `nixl_agent` and uses **ZMQ** (`PULL` on each agent, `PUSH` to peers) to ship **per-bucket** `bucket_meta` (tensor name, shape, dtype, byte offset) plus a notify key — the same bucket-descriptor pattern **`mx`** uses peer-to-peer between trainer and rollouts once peers are connected. |
+| **RDMA operation** | **`ReadOperation`**: `initialize_xfer("READ", local_descs, remote_descs, remote_agent, …)` then `transfer` / `check_xfer_state` until `DONE`. The initiator’s **local** buffer is the **destination**; **remote** is the **source** (trainer VRAM when reading from the trainer). |
+| **Trainer send path** | Only **trainer rank 0** runs `send_weights`. Other trainer ranks consume the weight generator and no-op (same pattern **`mx`** inherits). Rank 0 fills a bucket and uses **`ReadableOperation`**: it tells the **next** agent in the ring that its send buffer is readable; the next agent performs the **RDMA READ** from rank 0. |
+| **Rollout receive path** | Each rollout **`receive_weights`**: **READ from `prev_agent`**, slice tensors out of the bucket, **`yield`** to verl. If the rank has a **`next_agent`**, it also **`ReadableOperation`**s the same bucket metadata downstream so the next rank can READ — **pipeline along a chain**. |
+| **Topology** | **`build_topology`** builds a **ring** of size **`rollout_world_size + 1`**: rank `0` = trainer head; ranks `1…N` = rollouts. Trainer has **next** only; last rollout has **prev** only; middle ranks have **prev** and **next**. |
+| **Discovery / versioning** | The driver gathers each rank’s `NixlAgentMetadata` and installs **fixed `prev` / `next`** for that step — ideal when membership is stable and ordering is fully determined by Ray ranks. |
+| **Standalone mode** | Same requirement as **`mx`** for non-`naive` engines: rollout uses **`CheckpointEngineWorker`** (disaggregated GPUs). |
+
+```mermaid
+graph TB
+ subgraph ring["verl native NIXLCheckpointEngine — ring (conceptual)"]
+ T["Trainer rank 0 send_weights only"]
+ R1["Rollout 1 READ prev → forward"]
+ R2["Rollout 2 READ prev → forward"]
+ RN["Rollout N READ prev only"]
+ T --> R1
+ R1 --> R2
+ R2 --> RN
+ end
+
+ style T fill:#533483,stroke:#e94560,color:#fff
+ style R1 fill:#2e7d32,stroke:#66bb6a,color:#fff
+ style R2 fill:#2e7d32,stroke:#66bb6a,color:#fff
+ style RN fill:#2e7d32,stroke:#66bb6a,color:#fff
+```
+
+Weights still flow **trainer → rollouts**; intermediate rollouts **relay** each bucket downstream so the last rank receives the full model without the trainer fanning out to everyone directly.
+
+### Optional MX layer: same NIXL moves, different control plane
+
+| Dimension | Native **`nixl`** (verl) | With **`mx`** |
+|-----------|---------------------------|---------------|
+| **Topology** | **Ring**: each bucket walks rank `0 → 1 → … → N` with optional forward between rollouts | **Star**: each rollout **READs directly** from the trainer’s bucket (the trainer serves **N** READs per bucket, so fan-out lands on the trainer NIC) |
+| **Who owns “who do I read?”** | **Driver + Ray ranks**: `prev` / `next` from gathered `NixlAgentMetadata` | **MX Server + Redis** catalog: source identity, **version / step**, optional **worker_rank**, room for **multiple registered sources** |
+| **Rendezvous across jobs / clusters** | Topology is **defined by this job’s rank graph** | Consumers **resolve** a source via **gRPC** to a stable catalog service, then attach with **NIXL** as today |
+| **Extensibility** | New routing ideas are expressed in **verl driver / topology helpers** | Many policies can live in **catalog + client** without expanding core ring math for every deployment |
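+
+The two wirings in the table can be compared with a small sketch. These helpers are illustrative only and are not verl's `build_topology`; the function names and return shapes are assumptions.
+
+```python
+from typing import Dict, List, Optional, Tuple
+
+
+def ring_links(num_rollouts: int) -> Dict[int, Tuple[Optional[int], Optional[int]]]:
+    """Map rank -> (prev, next) for ranks 0..N, where rank 0 is the trainer head."""
+    last = num_rollouts
+    return {
+        rank: (rank - 1 if rank > 0 else None, rank + 1 if rank < last else None)
+        for rank in range(last + 1)
+    }
+
+
+def star_links(num_rollouts: int) -> Dict[int, List[int]]:
+    """Map the trainer (rank 0) to every rollout rank that reads from it directly."""
+    return {0: list(range(1, num_rollouts + 1))}
+```
+
+In the ring, the last rollout sits N hops from the trainer; in the star, every rollout is one hop away and the trainer serves all N reads.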
+
+#### What an MX catalog enables (additive)
+
+A Redis-backed MX catalog records **which processes publish which weight version** and the **NIXL metadata** needed to attach. That is **optional**; when you need it, it supports:
+
+- **Global view for load spreading** — Readers can ask the catalog **which source should serve this read** when **several replicas** expose the same version, instead of always following one fixed ring order.
+
+- **Load- and locality-aware steering** — Entries are **per-source identities**, so policy can prefer **less busy** or **network-closer** holders as the fleet grows, without each node probing the entire cluster.
+
+- **More than one read source over time** — Processes that have **finished** receiving a version can register as **additional sources**; the catalog can list **multiple holders** so bandwidth can spread through the pool (trainer plus peers).
+
+- **Publish / retire coordination** — A central record can track **in-flight reads** vs **writer reuse** of GPU pages so trainers retire a version before mutating shared buffers—coordination that is harder when each rank only knows ring neighbors.
+
+- **Retention with elastic membership** — A **cluster-wide** picture of which versions exist and where a **last readable copy** remains before unregister helps elastic scale-out / scale-in.
+
+- **Stable service names** — Trainer and rollout attach to **DNS-backed catalog endpoints** even when Ray placement or node pools change between runs.
+
+**Prefer native `nixl`** when a single Ray job, ring latency, and driver-wired `prev` / `next` are exactly what you want — nothing else required.
+
+**Consider `mx`** when you want **catalog-driven discovery**, **explicit versions**, **direct trainer→each-rollout** reads, or a path to **multi-source** and **policy-driven** routing **without** growing custom ring topology code in-tree for every site.
+
+Both backends use the same **`CheckpointEngine`** actor model and keep **bulk weight bytes off Ray**; both use **NIXL/RoCE** for the actual GPU transfers.
+
+---
+
+## 3. Timing Diagram — One `update_weights` Step
+
+```mermaid
+sequenceDiagram
+ participant D as Driver (CheckpointEngineManager)
+ participant T as Trainer WorkerDict
+ participant CE_T as MxCheckpointEngine (trainer)
+ participant MX as MX Server
+ participant CE_R as MxCheckpointEngine (rollout)
+ participant R as CheckpointEngineWorker (rollout)
+ participant V as vLLM ServerAdapter
+
+ Note over T: optimizer.step() complete
+
+ D->>T: ray.remote: update_weights(step=N)
+ D->>R: ray.remote: update_weights(step=N)
+
+ par prepare() on both sides
+ T->>CE_T: prepare()
+ CE_T->>CE_T: allocate GPU send bucket register with NIXL
+ CE_T->>MX: publish agent_meta + tensor layout
+ R->>CE_R: prepare()
+ CE_R->>CE_R: allocate GPU recv bucket register with NIXL
+ CE_R->>MX: poll_for_source(model_name)
+ MX-->>CE_R: trainer agent_meta
+ end
+
+ D->>D: build_topology() rank 0 → all rollouts
+
+ par init_process_group
+ CE_T->>CE_T: add rollout agents (NIXL)
+ CE_R->>CE_R: add trainer agent (NIXL)
+ CE_T-->>CE_R: ZMQ handshake (ip:port)
+ end
+
+ loop per bucket (streamed)
+ T->>CE_T: send_weights(yield name, tensor)
+ CE_T->>CE_T: pack tensor into GPU bucket
+ CE_T->>CE_R: ZMQ: bucket desc (name, shape, dtype, offset)
+ CE_R->>CE_T: NIXL RDMA READ (GPU→GPU, RoCE)
+ CE_R->>R: yield (name, tensor_view)
+ R->>V: load_weights([(name, tensor)])
+ end
+
+ par finalize
+ T->>CE_T: finalize() — deregister buffers
+ R->>CE_R: finalize() — deregister buffers
+ end
+
+ Note over V: vLLM resumes with updated weights
+```
+
+**Observed per-step timing** (GB200, Qwen2.5-1.5B BF16, cross-node RoCE):
+
+| Phase | Wall time |
+|-------|-----------|
+| `prepare` + `build_topology` + `init_process_group` | ~0.3-0.4s |
+| `send_weights` / `receive_weights` (RDMA) | ~0.6-0.8s |
+| `finalize` | ~0.1s |
+| **Total `update_weights`** | **~1.25s avg** (range 1.22-1.28s) |
+
+For the same model and cluster, the default `naive` engine averages **~1.6s** (in-process copy). The MX engine is faster *and* does a real cross-node transfer — the naive baseline only works because hybrid mode colocates trainer and rollout on the same GPUs.
+
+---
+
+## 4. ModelExpress and the Ray Actor Design
+
+verl's runtime is a web of Ray actors. Understanding where `MxCheckpointEngine` lives inside that web is the key to the integration.
+
+### Actor topology
+
+```mermaid
+graph TB
+ subgraph head["Ray Head · Node 1 · CPU driver"]
+ direction TB
+ tr["TaskRunner (@ray.remote)"]
+ ceman["CheckpointEngineManager (driver-side orchestrator)"]
+ tr --> ceman
+ end
+
+ subgraph trainer_pg["Trainer Placement Group · Node 1 · 4 GPUs"]
+ direction TB
+ wd0["WorkerDict 0 ActorRolloutRefWorker FSDP2 engine"]
+ wd1["WorkerDict 1"]
+ wd2["WorkerDict 2"]
+ wd3["WorkerDict 3"]
+ mxt["MxCheckpointEngine instance (per worker)"]
+ wd0 --> mxt
+ end
+
+ subgraph rollout_pg["Rollout Placement Group · Node 2 · 4 GPUs"]
+ direction TB
+ cew0["CheckpointEngineWorker 0 (@ray.remote)"]
+ cew1["CheckpointEngineWorker 1"]
+ cew2["CheckpointEngineWorker 2"]
+ cew3["CheckpointEngineWorker 3"]
+ mxr["MxCheckpointEngine instance (per worker)"]
+ vllm["vLLM Server actor × 4 (ServerAdapter)"]
+ cew0 --> mxr
+ mxr --> vllm
+ end
+
+ subgraph meta["MX Server Actor Group · CPU"]
+ mx["MX gRPC Server"]
+ rd[("Redis")]
+ mx --> rd
+ end
+
+ ceman -- "ray.get execute_checkpoint_engine(send)" --> wd0
+ ceman -- "ray.get execute_checkpoint_engine(recv)" --> cew0
+ mxt -- "gRPC" --> mx
+ mxr -- "gRPC" --> mx
+ mxt <== "NIXL RDMA (rank-paired)" ==> mxr
+
+ style head fill:#1a1a2e,stroke:#533483,color:#e0e0e0
+ style trainer_pg fill:#0f3460,stroke:#533483,color:#e0e0e0
+ style rollout_pg fill:#0f3460,stroke:#533483,color:#e0e0e0
+ style meta fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0
+ style tr fill:#533483,stroke:#e94560,color:#fff
+ style ceman fill:#533483,stroke:#e94560,color:#fff
+ style wd0 fill:#533483,stroke:#e94560,color:#fff
+ style wd1 fill:#162447,stroke:#533483,color:#e0e0e0
+ style wd2 fill:#162447,stroke:#533483,color:#e0e0e0
+ style wd3 fill:#162447,stroke:#533483,color:#e0e0e0
+ style cew0 fill:#2e7d32,stroke:#66bb6a,color:#fff
+ style cew1 fill:#2e7d32,stroke:#66bb6a,color:#fff
+ style cew2 fill:#2e7d32,stroke:#66bb6a,color:#fff
+ style cew3 fill:#2e7d32,stroke:#66bb6a,color:#fff
+ style vllm fill:#533483,stroke:#e94560,color:#fff
+ style mxt fill:#1b5e20,stroke:#4caf50,color:#fff
+ style mxr fill:#1b5e20,stroke:#4caf50,color:#fff
+ style mx fill:#1b5e20,stroke:#4caf50,color:#fff
+ style rd fill:#162447,stroke:#533483,color:#e0e0e0
+```
+
+### Three actor classes that matter
+
+1. **`TaskRunner`** — a single CPU Ray actor that owns the training loop. It holds the `CheckpointEngineManager` and drives PPO/GRPO iteration.
+2. **`WorkerDict` (trainer side)** — a Ray GPU actor per trainer rank. Hosts the FSDP2 model, optimizer, and — under our integration — an `MxCheckpointEngine` instance for the trainer role.
+3. **`CheckpointEngineWorker` (rollout side)** — a dedicated Ray GPU actor per rollout rank. It exists only in **standalone mode** (rollout on its own GPU pool). It hosts the `MxCheckpointEngine` in the rollout role and drives `ServerAdapter.load_weights` into the colocated vLLM engine.
+
+### Why standalone mode matters
+
+verl has two deployment modes for the rollout:
+
+| Mode | Ray actors | Status for MX |
+|------|-----------|--------------|
+| **Hybrid (colocated)** | `WorkerDict` does both training and rollout | **Not supported** — `WorkerDict` lacks an `execute_checkpoint_engine` method, so `CheckpointEngineManager` fails. |
+| **Standalone (disaggregated)** | Trainer uses `WorkerDict`, rollout uses `CheckpointEngineWorker` | **Supported** — full CE lifecycle available. |
+
+This is a verl framework constraint, not an MX constraint — the built-in `nixl` and `nccl` engines have the same requirement. Our prototype runs in standalone mode on 2 nodes.
+
+### How a weight sync crosses the actor boundary
+
+```text
+TaskRunner (Node 1)
+ └─► CheckpointEngineManager.update_weights(step=N)               # driver-side
+      ├─► ray.get([wd0.execute_checkpoint_engine("prepare"),      # fan-out to trainer
+      │            wd1.execute_checkpoint_engine("prepare"),
+      │            wd2.execute_checkpoint_engine("prepare"),
+      │            wd3.execute_checkpoint_engine("prepare")])
+      ├─► ray.get([cew0.execute_checkpoint_engine("prepare"),     # fan-out to rollout
+      │            cew1...cew3.execute_checkpoint_engine("prepare")])
+      ├─► build_topology(agent_meta_list)                         # computed on driver
+      ├─► ray.get([.init_process_group(topology) on both sides])
+      ├─► ray.get([wd0..3.send_weights(generator), cew0..3.receive_weights()])  # the RDMA moment
+      └─► ray.get([.finalize() on both sides])
+```
+
+The `execute_checkpoint_engine("method", *args)` pattern is how the manager dispatches a named lifecycle call onto every CE-hosting actor in parallel. Every `ray.get` is a fan-out over 4 trainer + 4 rollout actors, so all 8 ranks move in lock-step.
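+
+The named-dispatch pattern can be sketched without Ray. The actor class below is a hypothetical stand-in, and a thread pool plays the role of the parallel `ray.get` fan-out; none of this is verl's actual implementation.
+
+```python
+from concurrent.futures import ThreadPoolExecutor
+
+
+class FakeEngineActor:
+    """Hypothetical stand-in for a CE-hosting actor; no Ray involved."""
+
+    def __init__(self, rank: int):
+        self.rank = rank
+
+    def prepare(self) -> dict:
+        return {"rank": self.rank}
+
+    def execute_checkpoint_engine(self, method: str, *args):
+        # Resolve the lifecycle method by name and call it.
+        return getattr(self, method)(*args)
+
+
+def fan_out(actors, method: str, *args) -> list:
+    """Call one named method on every actor in parallel and wait for all results."""
+    with ThreadPoolExecutor(max_workers=len(actors)) as pool:
+        futures = [pool.submit(a.execute_checkpoint_engine, method, *args) for a in actors]
+        return [f.result() for f in futures]  # blocks until every actor finishes
+```
+
+Because the call only returns once every actor has finished, each lifecycle phase acts as a barrier across all ranks.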
+
+### Where `MxCheckpointEngine` is instantiated
+
+- On the **trainer**: `ActorRolloutRefWorker.init_model()` creates the engine when the config sets `actor_rollout_ref.rollout.checkpoint_engine.backend=mx`. The engine registers to handle `send_weights`.
+- On the **rollout**: `CheckpointEngineWorker.__init__` constructs the same class (via the `CheckpointEngineRegistry`) but with `role="rollout"`. It registers to handle `receive_weights` and to drive `load_weights` into the ServerAdapter.
+
+One class, two roles, distinguished by which actor type instantiates it. Both sides talk to the same MX Server over gRPC.
+
+---
+
+## 5. ModelExpress and the Checkpoint Engine
+
+verl's `CheckpointEngine` ABC is the small, well-defined plugin surface that makes MX a drop-in.
+
+### The ABC
+
+```python
+from abc import ABC
+from typing import AsyncGenerator
+
+
+class CheckpointEngine(ABC):
+    def prepare(self) -> dict: ...                              # allocate buffers, NIXL register
+    @classmethod
+    def build_topology(cls, trainer_meta, rollout_meta) -> tuple: ...
+    def init_process_group(self, topology, rank) -> None: ...
+    async def send_weights(self, weights_iter) -> None: ...     # trainer
+    async def receive_weights(self) -> AsyncGenerator: ...      # rollout
+    def finalize(self) -> None: ...
+
+All six methods are implemented in `verl/checkpoint_engine/mx_checkpoint_engine.py` (461 lines). The class is registered with one decorator:
+
+```python
+@CheckpointEngineRegistry.register("mx")
+class MxCheckpointEngine(CheckpointEngine):
+ ...
+```
+
+…which makes `backend: "mx"` selectable from Hydra config.
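+
+For readers unfamiliar with decorator-based registries, a toy version of the mechanism looks roughly like this. `ToyRegistry` and `ToyEngine` are hypothetical names; verl's `CheckpointEngineRegistry` may differ in detail.
+
+```python
+from typing import Callable, Dict
+
+
+class ToyRegistry:
+    """Toy name -> class registry (illustrative; not verl's implementation)."""
+
+    _engines: Dict[str, type] = {}
+
+    @classmethod
+    def register(cls, name: str) -> Callable[[type], type]:
+        def decorator(engine_cls: type) -> type:
+            cls._engines[name] = engine_cls  # record the class under its backend name
+            return engine_cls
+        return decorator
+
+    @classmethod
+    def get(cls, name: str) -> type:
+        return cls._engines[name]
+
+
+@ToyRegistry.register("toy")
+class ToyEngine:
+    pass
+```
+
+A config value such as `backend: toy` can then be resolved to a class with `ToyRegistry.get("toy")`.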
+
+### Lifecycle responsibilities
+
+```mermaid
+graph LR
+ subgraph life["MxCheckpointEngine lifecycle (one update_weights step)"]
+ direction LR
+ P["prepare()<br/>alloc GPU bucket<br/>NIXL register<br/>MX publish/discover"]
+ B["build_topology()<br/>star: rank 0 → all rollouts"]
+ I["init_process_group()<br/>NIXL add_remote_agent<br/>ZMQ handshake"]
+ S["send_weights()<br/>pack bucket<br/>push ZMQ desc<br/>wait for RDMA"]
+ R["receive_weights()<br/>pull ZMQ desc<br/>NIXL RDMA READ<br/>yield tensors"]
+ F["finalize()<br/>dereg NIXL<br/>close ZMQ"]
+ P --> B --> I
+ I --> S
+ I --> R
+ S --> F
+ R --> F
+ end
+
+ style life fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0
+ style P fill:#1b5e20,stroke:#4caf50,color:#fff
+ style B fill:#1b5e20,stroke:#4caf50,color:#fff
+ style I fill:#1b5e20,stroke:#4caf50,color:#fff
+ style S fill:#1b5e20,stroke:#4caf50,color:#fff
+ style R fill:#1b5e20,stroke:#4caf50,color:#fff
+ style F fill:#1b5e20,stroke:#4caf50,color:#fff
+```
+
+### Method-by-method behavior
+
+| Method | Trainer side | Rollout side |
+|--------|--------------|--------------|
+| `prepare()` | Allocate registered GPU send bucket, register with NIXL agent, publish agent metadata + tensor layout to MX Server via gRPC | Allocate GPU recv bucket, register with NIXL, call `MxClient.poll_for_source(model_name)` to get the trainer's agent blob |
+| `build_topology()` | Driver-side utility. Produces `(trainer_agent → [rollout_agents])` star mapping | Same (called on driver) |
+| `init_process_group()` | For each rollout rank: `nixl_agent.add_remote_agent(rollout_meta)` | `nixl_agent.add_remote_agent(trainer_meta)`; ZMQ PULL socket bound on free port, advertised to trainer |
+| `send_weights(iter)` | Consume `(name, tensor)` generator. Pack tensors into the GPU bucket at known offsets. Send `BucketDesc{name, shape, dtype, offset, nbytes}` over ZMQ PUSH. Block until rollout signals ACK | — |
+| `receive_weights()` | — | Pull `BucketDesc` over ZMQ. Issue NIXL RDMA READ into the recv bucket at the given offset. Yield `(name, tensor_view)` to verl, which forwards to `ServerAdapter.load_weights` |
+| `finalize()` | Deregister NIXL memory regions, close ZMQ, tell MX Server to retire this version | Same |
+
+### Why a bucket, not per-tensor RDMA
+
+Each NIXL RDMA transfer has a fixed latency overhead (~50-100µs). A 1.5B model has hundreds of parameter tensors. Issuing per-tensor transfers would spend more time in transfer setup than in data movement. The engine packs tensors into a contiguous GPU bucket (up to a configured size, e.g. 256 MB) and issues one RDMA READ per bucket. The ZMQ channel carries a list of `BucketDesc` entries so the receiver can slice the bucket back into named tensors.
+
+This is the same pattern used by verl's built-in NIXL engine. MX adopts it for parity — and because it works.
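+
+The pack-then-slice idea can be shown with a CPU-only sketch that uses plain byte buffers in place of registered GPU memory. The `BucketDesc` below carries only name, offset, and size (shape and dtype are omitted for brevity), and the helper names are assumptions, not the engine's real API.
+
+```python
+from dataclasses import dataclass
+from typing import Dict, Iterable, List, Tuple
+
+
+@dataclass
+class BucketDesc:
+    name: str
+    offset: int
+    nbytes: int
+
+
+def pack_bucket(tensors: Iterable[Tuple[str, bytes]], capacity: int) -> Tuple[bytearray, List[BucketDesc]]:
+    """Copy payloads back-to-back into one buffer and record where each one landed."""
+    bucket = bytearray(capacity)
+    descs: List[BucketDesc] = []
+    offset = 0
+    for name, payload in tensors:
+        if offset + len(payload) > capacity:
+            raise ValueError("bucket full; a real engine would flush and start a new bucket")
+        bucket[offset:offset + len(payload)] = payload
+        descs.append(BucketDesc(name, offset, len(payload)))
+        offset += len(payload)
+    return bucket, descs
+
+
+def unpack_bucket(bucket: bytearray, descs: List[BucketDesc]) -> Dict[str, bytes]:
+    """Slice the buffer back into named payloads using the descriptors."""
+    view = memoryview(bucket)
+    return {d.name: bytes(view[d.offset:d.offset + d.nbytes]) for d in descs}
+```
+
+One transfer of the whole buffer plus a small descriptor list replaces one transfer per tensor, which is where the setup-latency saving comes from.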
+
+### Config
+
+```yaml
+actor_rollout_ref:
+  rollout:
+    checkpoint_engine:
+      backend: mx
+      engine_kwargs:
+        mx_server_url: modelexpress-server.kavin.svc.cluster.local:8001
+        model_name: Qwen/Qwen2.5-1.5B
+        bucket_size_mb: 256 # optional, default 256
+        skip_sleep_wake: true # avoid vLLM multiproc sleep/wake crash on ARM64
+```
+
+### How verl uses `MxCheckpointEngine` on the tensor path
+
+verl streams `(name, tensor)` pairs through the `CheckpointEngine` API. `MxCheckpointEngine` packs those into registered GPU buckets; per-bucket metadata carries **name, shape, dtype, and byte offset** so `receive_weights` can slice views and hand them to **`ServerAdapter.load_weights`** without an extra layout file. That is the same bucket pattern as verl’s native `nixl` engine; **`mx`** swaps ring wiring for **catalog + star** discovery as described in §2.
+
+---
+
+## 6. Prototype on GB200 — Results
+
+### Cluster
+
+| Resource | Value |
+|----------|-------|
+| Platform | GKE DGXCloud, GB200 ARM64 |
+| Nodes | 2 (trainer + rollout), `hostNetwork=true` |
+| GPUs | 8 × GB200 (4 per node) |
+| Node pools | `customer-gpu-w0e` (trainer), `customer-gpu-o7v` (rollout) |
+| Fabric | RoCE v2, IMEX channels via DRA `compute-domain-channel` |
+| UCX | v1.20.0 built from source, transports `self,sm,rc,cuda_copy,gdr_copy,tcp` |
+| NIXL | 1.1.0 (main branch) |
+| PyTorch | 2.6 + cu128 |
+| vLLM | 0.18.1 (0.19.0 has an ARM64 multiproc `resource_tracker` bug) |
+| Image | `nvcr.io/nvidian/dynamo-dev/verl-mx:latest` |
+
+### Deployment
+
+```text
+Node 1 (gke-...-w0e-...-tz1d, IP 10.0.0.83)
+├─ Ray head StatefulSet (verl-mx-head-0)
+│ ├─ TaskRunner / CheckpointEngineManager
+│ └─ 4× WorkerDict (FSDP2 trainers) + 4× MxCheckpointEngine
+└─ 4× NIXL agent (rc_mlx5)
+
+Node 2 (gke-...-o7v-...-mflg, IP 10.0.15.225)
+├─ Ray worker StatefulSet (verl-mx-worker-0)
+│ ├─ 4× CheckpointEngineWorker + 4× MxCheckpointEngine
+│ └─ 4× vLLM ServerAdapter
+└─ 4× NIXL agent (rc_mlx5)
+
+MX Server + Redis (kavin namespace, reachable from both nodes over gRPC)
+```
+
+### What we observed
+
+- `[MX-DEBUG] Initializing 4 replicas in STANDALONE mode (worker_group=None)` — rollout running as dedicated CE workers, not fused WorkerDicts.
+- `[MX-DEBUG] Standalone replicas ([STANDALONE x4]), using mx checkpoint engine` — repeated 11 times across the run (one per `update_weights`).
+- `Backend UCX was instantiated` + `Initialized NIXL agent: ` on both nodes.
+- UCX `rc_mlx5` transport negotiated — confirmed RoCE data path.
+- Full lifecycle traced per step: `prepare → build_topology → init_process_group → send_weights / receive_weights → finalize`.
+
+### Per-step `update_weights` timing (MX engine, cross-node RDMA)
+
+| Step | `update_weights` (s) |
+|------|---------------------|
+| 1 | 1.278 |
+| 2 | 1.250 |
+| 3 | 1.233 |
+| 4 | 1.252 |
+| 5 | 1.243 |
+| 6 | 1.223 |
+| 7 | 1.235 |
+| 8 | 1.249 |
+| 9 | 1.263 |
+| 10 | 1.282 |
+| **Avg** | **≈ 1.25s** |
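+
+The average and the run-to-run variation can be recomputed from the table. The sketch below assumes the full ≈ 3 GB BF16 payload moves on every step, so the derived rate is an end-to-end figure for the whole `update_weights` call, not a measurement of link bandwidth.
+
+```python
+import statistics
+
+# update_weights timings in seconds, copied from the table above
+timings = [1.278, 1.250, 1.233, 1.252, 1.243, 1.223, 1.235, 1.249, 1.263, 1.282]
+
+mean_s = statistics.mean(timings)
+stdev_s = statistics.stdev(timings)
+spread_s = max(timings) - min(timings)
+
+# Assumption: roughly 3 GB of weights are transferred per step.
+payload_gb = 3.0
+effective_gb_per_s = payload_gb / mean_s
+
+print(f"mean={mean_s:.3f}s stdev={stdev_s:.3f}s spread={spread_s:.3f}s")
+print(f"effective end-to-end rate ~{effective_gb_per_s:.1f} GB/s")
+```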
+
+### Headline metrics
+
+| Metric | Value |
+|--------|-------|
+| Model | Qwen/Qwen2.5-1.5B (BF16, ≈ 3 GB resident) |
+| Steps completed | 10 (full PPO/GRPO run) |
+| Avg step time | ~8.1-8.8s |
+| Avg `update_weights` (MX) | **~1.25s** |
+| Avg `update_weights` (naive baseline, hybrid mode) | ~1.6s |
+| verl native `nixl` (same step, standalone) | *Not benchmarked in this doc* — same NIXL READ + buckets as **`mx`**, with **ring** topology and **no** MX Server |
+| Throughput | 135-163 tokens/sec |
+| Transport | NIXL / UCX `rc_mlx5` (RoCE RDMA) |
+| Data path | Cross-node GPU→GPU (no CPU staging, no filesystem) |
+
+### Manual NIXL transfer test (isolated, inside one pod)
+
+Before wiring the engine into the training loop, we ran a standalone `MxTrainingPublisher` → `MxRefitReceiver` test to validate the MX/NIXL data plane by itself:
+
+| Metric | Value |
+|--------|-------|
+| Payload | 10 tensors × 2 MiB ≈ 21 MB |
+| Transfer time | 0.16s |
+| Data integrity | All tensors byte-verified correct |
+| Environment | Loopback (same GPU, single pod) |
+
+This proved publisher/receiver correctness before moving to cross-node.
+
+### How we got there — build history
+
+21 Docker image iterations were needed to reach a working ARM64 build. Dominant issues and their fixes:
+
+| Category | Resolution |
+|----------|-----------|
+| NIXL build (missing tag, `pybind11`) | Clone NIXL `main`, add `pybind11-dev` to apt install |
+| PyTorch CPU-only on ARM64 | Install `torch==2.6` with `--index-url` `download.pytorch.org/whl/cu128` *before* vLLM; reinstall matching version after |
+| `flash_attn` absent on ARM64 | Ship `flash_attn_compat.py` in `verl/utils` — real SDPA fallbacks for GQA attention and cross-entropy |
+| vLLM 0.19 multiproc crash | Downgrade to 0.18.1 (stable on ARM64) |
+| Triton JIT — missing gcc / nvcc | Base runtime on `cuda:12.8.1-devel`, add `gcc/g++` |
+| `cupy` not installed | Made optional in `MxCheckpointEngine`, added torch fallback for bucket allocation |
+| Ray worker → head GCS | Use head's FQDN `verl-mx-head-0.verl-mx-head.kavin.svc.cluster.local:6379` |
+| `CheckpointEngineManager` fails in hybrid mode | Deploy in standalone (2-node) mode — matches built-in NIXL/NCCL engine requirements |
+
+### What this proves
+
+1. **MX works on Ray.** The `CheckpointEngine` plugin surface is sufficient to express a star-topology RDMA transfer with server-mediated discovery.
+2. **Cross-node RoCE RDMA is real.** `update_weights` at 1.25s for a 3 GB model on a 2-node GB200 cluster is consistent with UCX `rc_mlx5` over RoCE and beats the in-process naive baseline even before we tune the bucket size.
+3. **The ARM64 path is painful but survivable.** All the image work is in `docker/Dockerfile.mx-arm64` and the compat shim, aligned with other GB200 MX + vLLM container iterations on the same stack.
+4. **Standalone rollout is the production shape.** Hybrid/colocated mode is useful for debugging but cannot drive any non-naive checkpoint engine — true for NIXL, NCCL, and MX alike.
+
diff --git a/docs/RL/scripts/README.md b/docs/RL/scripts/README.md
new file mode 100644
index 00000000..ce71366f
--- /dev/null
+++ b/docs/RL/scripts/README.md
@@ -0,0 +1,105 @@
+# NIXL Compression Study — Capture Scripts
+
+Two scripts for producing the data described in [`../NIXL_COMPRESSION_STUDY.md`](../NIXL_COMPRESSION_STUDY.md).
+
+| Script | When to use |
+|--------|-------------|
+| `capture_weights_and_kv.py` | **Standalone** — capture from any HuggingFace model on any host. Doesn't require a running RL deployment. Just downloads the model and dumps weights + KV cache. CLI flags for model, dtype, device, output dir. |
+| `capture_on_pod.py` | **Inside a running RL pod** — exec into a trainer pod and capture pre-step weights, simulate one AdamW step, capture post-step weights + delta + KV cache in one pass. Produces the four-directory layout we shipped to the NIXL team for Qwen2.5-1.5B. |
+
+## Standalone capture (any model, no cluster needed)
+
+```bash
+pip install torch transformers safetensors accelerate  # accelerate: needed because the script passes device_map to from_pretrained
+
+# Smallest model — ~3 GB output, ~5 minutes total
+python capture_weights_and_kv.py \
+ --model Qwen/Qwen2.5-1.5B \
+ --output-dir ./nixl_data \
+ --dtype bfloat16 \
+ --device cpu \
+ --kv-seq-len 512
+
+# Larger model
+python capture_weights_and_kv.py \
+ --model meta-llama/Llama-3.1-8B-Instruct \
+ --output-dir ./nixl_data \
+ --dtype bfloat16 \
+ --device cuda:0 \
+ --kv-seq-len 2048
+
+# Weights only / KV only
+python capture_weights_and_kv.py --model Qwen/Qwen2.5-1.5B --output-dir ./nixl_data --weights-only
+python capture_weights_and_kv.py --model Qwen/Qwen2.5-1.5B --output-dir ./nixl_data --kv-only
+```
+
+Output layout:
+
+```text
+{output_dir}/
+├── weights/{model_name}/      # "/" and "." in the model name are replaced by "_"
+│   ├── tensors/*.bin          # one file per parameter, raw bytes (BF16)
+│   └── manifest.json          # name, shape, dtype, size, layer index, classification, stats
+└── kvcache/{model_name}/
+    ├── layer_N_key.bin        # one file per (layer, key/value)
+    ├── layer_N_value.bin
+    └── manifest.json
+```
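+
+Reading a capture back needs only the manifest and the raw files. The sketch below is a standard-library illustration, assuming BF16 tensors on a little-endian host; `bf16_bytes_to_floats` and `load_tensor` are hypothetical helpers, while the `file`, `shape`, and `size_bytes` keys are the ones the script writes. `model_dir` stands for the per-model directory under `weights/`.
+
+```python
+import json
+import struct
+import tempfile
+from pathlib import Path
+
+
+def bf16_bytes_to_floats(raw):
+    """Decode little-endian BF16 bytes into Python floats.
+
+    A BF16 value is the upper 16 bits of an IEEE-754 float32, so placing
+    two zero bytes below it yields the equivalent float32.
+    """
+    return [struct.unpack("<f", b"\x00\x00" + raw[i:i + 2])[0]
+            for i in range(0, len(raw), 2)]
+
+
+def load_tensor(model_dir, entry):
+    """Return (shape, values) for one manifest entry; values are flat, row-major."""
+    raw = (Path(model_dir) / entry["file"]).read_bytes()
+    assert len(raw) == entry["size_bytes"], "file size does not match manifest"
+    return entry["shape"], bf16_bytes_to_floats(raw)
+
+
+# Self-contained demo: write a tiny fake capture, then read it back.
+with tempfile.TemporaryDirectory() as tmp:
+    model_dir = Path(tmp)
+    (model_dir / "tensors").mkdir()
+    # 1.0 and -2.0 in BF16 are 0x3F80 and 0xC000, stored little-endian.
+    (model_dir / "tensors" / "demo.bin").write_bytes(bytes([0x80, 0x3F, 0x00, 0xC0]))
+    manifest = {"tensors": [{"name": "demo", "file": "tensors/demo.bin",
+                             "shape": [2], "dtype": "torch.bfloat16", "size_bytes": 4}]}
+    (model_dir / "manifest.json").write_text(json.dumps(manifest))
+
+    loaded = json.loads((model_dir / "manifest.json").read_text())
+    shape, values = load_tensor(model_dir, loaded["tensors"][0])
+```
+
+For real captures a tensor library is the practical choice; the pure-Python decode is only meant to make the byte layout explicit.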
+
+## Capture from a running pod (with RL-step simulation)
+
+This script is what produced the `RL_Qwen25/` package referenced in the NIXL request. It captures weights before and after a simulated AdamW step, then computes the delta:
+
+```bash
+# Copy the script into the pod
+kubectl cp capture_on_pod.py <namespace>/<pod>:/tmp/capture.py
+
+# Run it (uses /app/.venv/bin/python in our overlay image; adjust if different)
+kubectl exec -n <namespace> <pod> -- /app/.venv/bin/python /tmp/capture.py \
+ --model Qwen/Qwen2.5-1.5B \
+ --out /tmp/nixl_capture \
+ --kv-seq-len 512 \
+ --lr 5e-6
+
+# Copy results back
+kubectl cp <namespace>/<pod>:/tmp/nixl_capture ./RL_capture
+```
+
+Output layout (matches what we shipped to the NIXL team):
+
+```text
+nixl_capture/
+├── weights_pre_rl/ # pre-step weight tensors + manifest.json
+├── weights_post_rl/ # post-step weight tensors + manifest.json
+├── weight_deltas/ # post - pre (BF16; mostly zero — see note below)
+└── kv_cache/ # one prefill pass output + manifest.json
+```
+
+### Note on BF16 deltas
+
+A single AdamW step at `lr=5e-6` produces parameter updates of magnitude ~1e-8 to 1e-6, which is **below BF16's representable precision** at typical weight magnitudes (0.01–0.1). The `weight_deltas/` files will therefore be mostly zero in BF16.
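+
+The precision argument can be checked with a few lines of arithmetic. The sketch below relies only on the BF16 format (7 explicit mantissa bits); `bf16_ulp` is a hypothetical helper, and the update is taken to be as large as `lr` itself, which is more generous than the range quoted above.
+
+```python
+import math
+
+
+def bf16_ulp(x):
+    """Spacing between adjacent BF16 values near x (normal range only).
+
+    BF16 stores 7 explicit mantissa bits, so values in [2**e, 2**(e+1))
+    are spaced 2**(e - 7) apart.
+    """
+    exponent = math.floor(math.log2(abs(x)))
+    return 2.0 ** (exponent - 7)
+
+
+lr = 5e-6  # learning rate used in the capture command above
+for weight in (0.01, 0.1):
+    ulp = bf16_ulp(weight)
+    # With round-to-nearest, an update below half the spacing typically
+    # leaves the stored BF16 value unchanged.
+    survives = lr >= ulp / 2
+    print(f"|w|={weight}: ulp={ulp:.2e}, half-ulp={ulp / 2:.2e}, "
+          f"update of {lr:.0e} survives: {survives}")
+```
+
+At both magnitudes the spacing is orders of magnitude larger than the update, which is why the stored deltas come out mostly zero.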
+
+For meaningful delta-compression analysis, compute the diff in FP32 from the pre/post safetensors:
+
+```python
+import torch
+from safetensors import safe_open
+
+with safe_open("weights_pre_rl/...", framework="pt") as pre, \
+     safe_open("weights_post_rl/...", framework="pt") as post:
+    for k in pre.keys():
+        delta_fp32 = post.get_tensor(k).float() - pre.get_tensor(k).float()
+        if delta_fp32.abs().max() > 0:
+            print(f"{k}: max_abs_delta={delta_fp32.abs().max():.2e}")
+```
+
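+The pre/post tensors are saved as raw BF16 (the on-the-wire dtype). The FP32 delta is the meaningful analysis target: it is the signal nvCOMP would compress in a delta-transfer scheme. Because both operands are already BF16-rounded, this diff only recovers updates large enough to change the stored BF16 value; observing smaller updates requires FP32 master weights on the trainer side.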
+
+## What gets captured
+
+See [`../NIXL_COMPRESSION_STUDY.md`](../NIXL_COMPRESSION_STUDY.md) for the full breakdown of:
+
+- Per-tensor layout (Qwen3-0.6B and Qwen2.5-1.5B examples)
+- KV cache shape + scaling table
+- Compression-relevant properties
+- Where these tensors fit in the NIXL transfer path
diff --git a/docs/RL/scripts/capture_on_pod.py b/docs/RL/scripts/capture_on_pod.py
new file mode 100644
index 00000000..5e3e510b
--- /dev/null
+++ b/docs/RL/scripts/capture_on_pod.py
@@ -0,0 +1,197 @@
+#!/usr/bin/env python3
+"""Capture weight and KV cache data from a running PRIME-RL / verl deployment.
+
+Designed to be exec'd inside a trainer pod. Captures four artifacts in the
+same shape we shipped to the NIXL nvCOMP compression team for Qwen2.5-1.5B:
+
+ weights_pre_rl/ raw .bin tensors + manifest.json (pre-step state dict)
+ weights_post_rl/ raw .bin tensors + manifest.json (after one AdamW step)
+ weight_deltas/ raw .bin tensors + manifest.json (post - pre)
+ kv_cache/ raw .bin tensors + manifest.json (one prefill pass)
+
+Then a final summary line tells you how to `kubectl cp` it out.
+
+Usage (inside pod):
+ python3 capture_on_pod.py
+ python3 capture_on_pod.py --model Qwen/Qwen2.5-7B --out /tmp/nixl_capture --kv-seq-len 1024
+
+Usage (from host, via kubectl):
+    kubectl cp capture_on_pod.py <namespace>/<pod>:/tmp/capture.py
+    kubectl exec -n <namespace> <pod> -- python3 /tmp/capture.py --model Qwen/Qwen2.5-1.5B
+    kubectl cp <namespace>/<pod>:/tmp/nixl_capture ./RL_capture
+"""
+import argparse
+import json
+
+import torch
+from pathlib import Path
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+parser = argparse.ArgumentParser()
+parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B",
+ help="HuggingFace model name (must match the running RL deployment)")
+parser.add_argument("--out", default="/tmp/nixl_capture",
+ help="Output directory inside the pod")
+parser.add_argument("--kv-seq-len", type=int, default=512,
+ help="Sequence length for the KV cache prefill pass")
+parser.add_argument("--lr", type=float, default=5e-6,
+ help="Learning rate for the simulated AdamW step (matches PRIME-RL default)")
+args = parser.parse_args()
+
+MODEL = args.model
+OUT = Path(args.out)
+OUT.mkdir(parents=True, exist_ok=True)
+
+def tensor_stats(t):
+    ft = t.float()
+    return {"min": float(ft.min()), "max": float(ft.max()), "mean": float(ft.mean()),
+            "std": float(ft.std()), "abs_mean": float(ft.abs().mean()),
+            "zero_frac": float((t == 0).float().mean())}
+
+def classify(name):
+    for k, v in [("embed", "embedding"), ("lm_head", "lm_head"), ("norm", "norm"),
+                 ("q_proj", "attn_q"), ("k_proj", "attn_k"), ("v_proj", "attn_v"),
+                 ("o_proj", "attn_o"), ("gate_proj", "mlp_gate"), ("up_proj", "mlp_up"),
+                 ("down_proj", "mlp_down")]:
+        if k in name:
+            return v
+    return "other"
+
+def layer_idx(name):
+    parts = name.split(".")
+    for i, p in enumerate(parts):
+        if p == "layers" and i + 1 < len(parts) and parts[i+1].isdigit():
+            return int(parts[i+1])
+    return -1
+
+print(f"Loading {MODEL}...")
+tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True)
+model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True)
+model.eval()
+
+# --- 1. Dump weights (pre-RL step) ---
+print("\n=== Capturing pre-RL weights ===")
+wdir = OUT / "weights_pre_rl"
+wdir.mkdir(exist_ok=True)
+manifest = {"model": MODEL, "dtype": "bfloat16", "capture": "pre_rl_weights",
+ "description": "Exact weight tensors transferred during RL refit (training->inference via NIXL RDMA)",
+ "tensors": []}
+total = 0
+for name, param in model.named_parameters():
+    t = param.data.contiguous()
+    raw = bytes(t.untyped_storage())[:t.numel() * t.element_size()]
+    fname = name.replace(".", "_") + ".bin"
+    (wdir / fname).write_bytes(raw)
+    manifest["tensors"].append({"name": name, "file": fname, "shape": list(t.shape),
+                                "dtype": str(t.dtype), "size_bytes": len(raw), "numel": t.numel(),
+                                "layer": layer_idx(name), "type": classify(name), "stats": tensor_stats(t)})
+    total += len(raw)
+manifest["total_bytes"] = total
+manifest["total_gb"] = round(total / 1e9, 3)
+manifest["num_tensors"] = len(manifest["tensors"])
+cfg = model.config
+manifest["model_config"] = {"num_hidden_layers": cfg.num_hidden_layers, "hidden_size": cfg.hidden_size,
+ "intermediate_size": cfg.intermediate_size, "num_attention_heads": cfg.num_attention_heads,
+ "num_key_value_heads": cfg.num_key_value_heads, "vocab_size": cfg.vocab_size}
+(wdir / "manifest.json").write_text(json.dumps(manifest, indent=2))
+print(f" {manifest['num_tensors']} tensors, {manifest['total_gb']} GB -> {wdir}")
+
+# --- 2. KV cache ---
+print(f"\n=== Capturing KV cache (seq_len={args.kv_seq_len}) ===")
+kvdir = OUT / "kv_cache"
+kvdir.mkdir(exist_ok=True)
+prompt = "The quick brown fox jumps over the lazy dog. " * (args.kv_seq_len // 10 + 1)
+inputs = tokenizer(prompt, return_tensors="pt", max_length=args.kv_seq_len, truncation=True)
+with torch.no_grad():
+    outputs = model(**inputs, use_cache=True)
+kv = outputs.past_key_values
+kv_manifest = {"model": MODEL, "capture": "kv_cache_prefill", "seq_len": int(inputs["input_ids"].shape[1]),
+ "description": "KV cache from prefill pass - transferred prefill->decode via NIXL in disagg inference",
+ "tensors": []}
+kv_total = 0
+for li, layer_kv in enumerate(kv):
+    for ki, kn in enumerate(["key", "value"]):
+        t = layer_kv[ki].contiguous()
+        raw = bytes(t.untyped_storage())[:t.numel() * t.element_size()]
+        fname = f"layer_{li}_{kn}.bin"
+        (kvdir / fname).write_bytes(raw)
+        kv_manifest["tensors"].append({"name": f"layer_{li}.{kn}", "file": fname,
+                                       "shape": list(t.shape), "dtype": str(t.dtype), "size_bytes": len(raw),
+                                       "layer": li, "kv_type": kn, "stats": tensor_stats(t)})
+        kv_total += len(raw)
+kv_manifest["total_bytes"] = kv_total
+kv_manifest["total_mb"] = round(kv_total / 1e6, 3)
+# Prefer an explicit head_dim when the config defines one; hidden_size // heads is only a fallback.
+kv_manifest["kv_config"] = {"num_layers": len(kv),
+                            "num_kv_heads": cfg.num_key_value_heads,
+                            "head_dim": getattr(cfg, "head_dim", None)
+                                        or cfg.hidden_size // cfg.num_attention_heads}
+(kvdir / "manifest.json").write_text(json.dumps(kv_manifest, indent=2))
+print(f" {len(kv_manifest['tensors'])} tensors, {kv_manifest['total_mb']} MB -> {kvdir}")
+
+# --- 3. Simulate one RL step and capture post-RL weights + delta ---
+print(f"\n=== Simulating RL step (AdamW, lr={args.lr}, dummy loss) ===")
+model.train()
+optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01)
+dummy_input = tokenizer("Hello world", return_tensors="pt")
+output = model(**dummy_input, labels=dummy_input["input_ids"])
+output.loss.backward()
+optimizer.step()
+optimizer.zero_grad()
+model.eval()
+
+print("\n=== Capturing post-RL weights ===")
+wdir2 = OUT / "weights_post_rl"
+wdir2.mkdir(exist_ok=True)
+manifest2 = {"model": MODEL, "dtype": "bfloat16", "capture": "post_rl_weights",
+             # Report the learning rate actually used rather than a hardcoded value.
+             "description": f"Weights after 1 RL optimizer step (lr={args.lr}, AdamW)", "tensors": []}
+total2 = 0
+for name, param in model.named_parameters():
+    t = param.data.contiguous()
+    raw = bytes(t.untyped_storage())[:t.numel() * t.element_size()]
+    fname = name.replace(".", "_") + ".bin"
+    (wdir2 / fname).write_bytes(raw)
+    manifest2["tensors"].append({"name": name, "file": fname, "shape": list(t.shape),
+                                 "dtype": str(t.dtype), "size_bytes": len(raw), "numel": t.numel(),
+                                 "layer": layer_idx(name), "type": classify(name), "stats": tensor_stats(t)})
+    total2 += len(raw)
+manifest2["total_bytes"] = total2
+manifest2["total_gb"] = round(total2 / 1e9, 3)
+(wdir2 / "manifest.json").write_text(json.dumps(manifest2, indent=2))
+print(f" {len(manifest2['tensors'])} tensors, {manifest2['total_gb']} GB -> {wdir2}")
+
+# --- 4. Compute and save deltas ---
+print("\n=== Computing weight deltas (post - pre) ===")
+ddir = OUT / "weight_deltas"
+ddir.mkdir(exist_ok=True)
+delta_manifest = {"model": MODEL, "capture": "weight_delta_1_step",
+ "description": (
+ f"Difference between weights after 1 RL step vs before. "
+ f"RL uses lr={args.lr} so deltas are tiny — at BF16 most are exactly zero "
+ f"(below mantissa precision). For meaningful delta-compression analysis, "
+ f"compute diffs in FP32 from the pre/post safetensors instead of using "
+ f"this BF16-stored delta directly."
+ ),
+ "tensors": []}
+dtotal = 0
+pre_files = {m["name"]: m["file"] for m in manifest["tensors"]}
+for info in manifest2["tensors"]:
+    pre_raw = (OUT / "weights_pre_rl" / pre_files[info["name"]]).read_bytes()
+    post_raw = (wdir2 / info["file"]).read_bytes()
+    pre_t = torch.frombuffer(bytearray(pre_raw), dtype=torch.bfloat16).reshape(info["shape"])
+    post_t = torch.frombuffer(bytearray(post_raw), dtype=torch.bfloat16).reshape(info["shape"])
+    delta = post_t - pre_t
+    delta_raw = bytes(delta.contiguous().untyped_storage())[:delta.numel() * delta.element_size()]
+    fname = "delta_" + info["file"]
+    (ddir / fname).write_bytes(delta_raw)
+    delta_manifest["tensors"].append({"name": info["name"], "file": fname,
+                                      "shape": info["shape"], "dtype": "bfloat16", "size_bytes": len(delta_raw),
+                                      "stats": tensor_stats(delta)})
+    dtotal += len(delta_raw)
+delta_manifest["total_bytes"] = dtotal
+delta_manifest["total_gb"] = round(dtotal / 1e9, 3)
+(ddir / "manifest.json").write_text(json.dumps(delta_manifest, indent=2))
+print(f" {len(delta_manifest['tensors'])} delta tensors, {delta_manifest['total_gb']} GB -> {ddir}")
+
+# --- Summary ---
+print(f"\n=== DONE ===")
+print(f"Output: {OUT}")
+print(f" weights_pre_rl/ : {manifest['total_gb']} GB ({manifest['num_tensors']} tensors)")
+print(f" weights_post_rl/ : {manifest2['total_gb']} GB")
+print(f" weight_deltas/ : {delta_manifest['total_gb']} GB")
+print(f" kv_cache/ : {kv_manifest['total_mb']} MB ({len(kv_manifest['tensors'])} tensors)")
+print(f"\nTo copy out: kubectl cp <namespace>/<pod>:{OUT} ./RL_capture")
diff --git a/docs/RL/scripts/capture_weights_and_kv.py b/docs/RL/scripts/capture_weights_and_kv.py
new file mode 100644
index 00000000..38c6d215
--- /dev/null
+++ b/docs/RL/scripts/capture_weights_and_kv.py
@@ -0,0 +1,242 @@
+#!/usr/bin/env python3
+"""Capture raw model weights and KV cache data for NIXL nvCOMP compression study.
+
+Outputs:
+ weights/{model_name}/tensors/*.bin — raw weight tensors (as transferred during RL refit)
+ weights/{model_name}/manifest.json — per-tensor metadata
+ kvcache/{model_name}/*.bin — KV cache tensors from a sample forward pass
+ kvcache/{model_name}/manifest.json — per-KV-tensor metadata
+
+Usage:
+ python capture_weights_and_kv.py --model Qwen/Qwen2.5-1.5B --output-dir ./nixl_data
+ python capture_weights_and_kv.py --model Qwen/Qwen2.5-1.5B --output-dir ./nixl_data --kv-only
+ python capture_weights_and_kv.py --model Qwen/Qwen2.5-1.5B --output-dir ./nixl_data --weights-only
+"""
+
+import argparse
+import json
+from pathlib import Path
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+
+def sanitize_name(name: str) -> str:
+    return name.replace("/", "_").replace(".", "_")
+
+
+def tensor_stats(t: torch.Tensor) -> dict:
+    with torch.no_grad():
+        ft = t.float()
+        return {
+            "min": float(ft.min()),
+            "max": float(ft.max()),
+            "mean": float(ft.mean()),
+            "std": float(ft.std()),
+            "abs_mean": float(ft.abs().mean()),
+            "zero_fraction": float((t == 0).float().mean()),
+        }
+
+
+def classify_tensor(name: str) -> str:
+    if "embed" in name:
+        return "embedding"
+    if "lm_head" in name:
+        return "lm_head"
+    # "norm" also matches names containing "layernorm".
+    if "norm" in name:
+        return "norm"
+    if "q_proj" in name:
+        return "attention_q"
+    if "k_proj" in name:
+        return "attention_k"
+    if "v_proj" in name:
+        return "attention_v"
+    if "o_proj" in name:
+        return "attention_o"
+    if "gate_proj" in name or "w1" in name:
+        return "mlp_gate"
+    if "up_proj" in name or "w3" in name:
+        return "mlp_up"
+    if "down_proj" in name or "w2" in name:
+        return "mlp_down"
+    return "other"
+
+
+def get_layer_index(name: str) -> int:
+    parts = name.split(".")
+    for i, part in enumerate(parts):
+        if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit():
+            return int(parts[i + 1])
+    return -1
+
+
+def capture_weights(model, model_name: str, output_dir: Path, dtype_name: str):
+    """Dump all model weight tensors as raw binary files + manifest."""
+    safe_name = sanitize_name(model_name)
+    weight_dir = output_dir / "weights" / safe_name / "tensors"
+    weight_dir.mkdir(parents=True, exist_ok=True)
+
+    manifest = {
+        "model_name": model_name,
+        "dtype": dtype_name,
+        "capture_type": "model_weights_for_rl_refit",
+        "description": (
+            "These are the exact weight tensors transferred from training GPUs to "
+            "inference GPUs during the RL refit (weight sync) phase. In RL post-training, "
+            "after each optimizer step, the full model state dict is gathered and sent to "
+            "the inference engine (vLLM). These tensors represent that payload."
+        ),
+        "tensors": [],
+    }
+
+    total_bytes = 0
+    for name, param in model.named_parameters():
+        t = param.data.contiguous().cpu()
+        raw_bytes = bytes(t.untyped_storage())[:t.numel() * t.element_size()]
+        fname = sanitize_name(name) + ".bin"
+        (weight_dir / fname).write_bytes(raw_bytes)
+
+        info = {
+            "name": name,
+            "file": f"tensors/{fname}",
+            "shape": list(t.shape),
+            "dtype": str(t.dtype),
+            "size_bytes": len(raw_bytes),
+            "numel": t.numel(),
+            "layer_index": get_layer_index(name),
+            "tensor_type": classify_tensor(name),
+            "stats": tensor_stats(t),
+        }
+        manifest["tensors"].append(info)
+        total_bytes += len(raw_bytes)
+
+    manifest["total_tensors"] = len(manifest["tensors"])
+    manifest["total_bytes"] = total_bytes
+    manifest["total_gb"] = round(total_bytes / 1e9, 3)
+
+    config = model.config
+    manifest["model_config"] = {
+        "num_hidden_layers": getattr(config, "num_hidden_layers", None),
+        "hidden_size": getattr(config, "hidden_size", None),
+        "intermediate_size": getattr(config, "intermediate_size", None),
+        "num_attention_heads": getattr(config, "num_attention_heads", None),
+        "num_key_value_heads": getattr(config, "num_key_value_heads", None),
+        "vocab_size": getattr(config, "vocab_size", None),
+        "max_position_embeddings": getattr(config, "max_position_embeddings", None),
+        "architecture": config.architectures[0] if hasattr(config, "architectures") and config.architectures else None,
+    }
+
+    manifest_path = output_dir / "weights" / safe_name / "manifest.json"
+    manifest_path.write_text(json.dumps(manifest, indent=2))
+    print(f"Weights captured: {len(manifest['tensors'])} tensors, {manifest['total_gb']} GB → {weight_dir}")
+    return manifest
+
+
+def capture_kvcache(model, tokenizer, model_name: str, output_dir: Path, seq_len: int = 512):
+    """Run a forward pass and capture the KV cache tensors."""
+    safe_name = sanitize_name(model_name)
+    kv_dir = output_dir / "kvcache" / safe_name
+    kv_dir.mkdir(parents=True, exist_ok=True)
+
+    # +1 repetition so that small seq_len values still yield a non-empty prompt;
+    # the tokenizer truncates anything beyond seq_len.
+    prompt = "The quick brown fox jumps over the lazy dog. " * (seq_len // 10 + 1)
+    inputs = tokenizer(prompt, return_tensors="pt", max_length=seq_len, truncation=True)
+
+    device = next(model.parameters()).device
+    inputs = {k: v.to(device) for k, v in inputs.items()}
+
+    with torch.no_grad():
+        outputs = model(**inputs, use_cache=True)
+
+    past_kv = outputs.past_key_values
+
+    manifest = {
+        "model_name": model_name,
+        "capture_type": "kv_cache_prefill_to_decode",
+        "description": (
+            "These are the KV cache tensors produced during the prefill phase and "
+            "sent to decode workers in disaggregated inference (prefill/decode split). "
+            "In NIXL-based KV transfer, these are the exact tensors transferred "
+            "GPU-to-GPU between prefill and decode nodes."
+        ),
+        "sequence_length": int(inputs["input_ids"].shape[1]),
+        "batch_size": 1,
+        "tensors": [],
+    }
+
+    total_bytes = 0
+    for layer_idx, layer_kv in enumerate(past_kv):
+        for kv_idx, kv_name in enumerate(["key", "value"]):
+            t = layer_kv[kv_idx].contiguous().cpu()
+            raw_bytes = bytes(t.untyped_storage())[:t.numel() * t.element_size()]
+            fname = f"layer_{layer_idx}_{kv_name}.bin"
+            (kv_dir / fname).write_bytes(raw_bytes)
+
+            info = {
+                "name": f"layer_{layer_idx}.{kv_name}",
+                "file": fname,
+                "shape": list(t.shape),
+                "dtype": str(t.dtype),
+                "size_bytes": len(raw_bytes),
+                "layer_index": layer_idx,
+                "kv_type": kv_name,
+                "stats": tensor_stats(t),
+            }
+            manifest["tensors"].append(info)
+            total_bytes += len(raw_bytes)
+
+    manifest["total_tensors"] = len(manifest["tensors"])
+    manifest["total_bytes"] = total_bytes
+    manifest["total_mb"] = round(total_bytes / 1e6, 3)
+
+    config = model.config
+    manifest["kv_config"] = {
+        "num_layers": len(past_kv),
+        "num_kv_heads": getattr(config, "num_key_value_heads", getattr(config, "num_attention_heads", None)),
+        # Prefer an explicit head_dim when the config defines one; hidden_size // heads is only a fallback.
+        "head_dim": getattr(config, "head_dim", None)
+        or getattr(config, "hidden_size", 0) // getattr(config, "num_attention_heads", 1),
+        "kv_dtype": str(past_kv[0][0].dtype),
+    }
+
+    manifest_path = kv_dir / "manifest.json"
+    manifest_path.write_text(json.dumps(manifest, indent=2))
+    print(f"KV cache captured: {len(manifest['tensors'])} tensors, {manifest['total_mb']} MB → {kv_dir}")
+    return manifest
+
+
+def main():
+    parser = argparse.ArgumentParser(description="Capture weight and KV cache data for NIXL compression study")
+    parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-1.5B", help="HuggingFace model name")
+    parser.add_argument("--output-dir", type=str, default="./nixl_compression_data", help="Output directory")
+    parser.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"])
+    parser.add_argument("--device", type=str, default="cpu", help="Device to load model on (cpu or cuda:0)")
+    parser.add_argument("--weights-only", action="store_true", help="Only capture weights, skip KV cache")
+    parser.add_argument("--kv-only", action="store_true", help="Only capture KV cache, skip weights")
+    parser.add_argument("--kv-seq-len", type=int, default=512, help="Sequence length for KV cache capture")
+    args = parser.parse_args()
+    # Passing both "only" flags would skip both captures, so reject that combination.
+    if args.weights_only and args.kv_only:
+        parser.error("--weights-only and --kv-only are mutually exclusive")
+
+    output_dir = Path(args.output_dir)
+    dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32}
+    dtype = dtype_map[args.dtype]
+
+    print(f"Loading model: {args.model} (dtype={args.dtype}, device={args.device})")
+    tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True)
+    model = AutoModelForCausalLM.from_pretrained(
+        args.model, torch_dtype=dtype, trust_remote_code=True, device_map=args.device,
+    )
+    model.eval()
+
+    if not args.kv_only:
+        print("\n=== Capturing model weights (RL refit payload) ===")
+        capture_weights(model, args.model, output_dir, args.dtype)
+
+    if not args.weights_only:
+        print(f"\n=== Capturing KV cache (prefill→decode, seq_len={args.kv_seq_len}) ===")
+        capture_kvcache(model, tokenizer, args.model, output_dir, seq_len=args.kv_seq_len)
+
+    print(f"\nDone. Data written to {output_dir}/")
+    print("Send the output directory to the NIXL compression team.")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/docs/slides/diagram-architecture.svg b/docs/slides/diagram-architecture.svg
new file mode 100644
index 00000000..9e77b4a4
--- /dev/null
+++ b/docs/slides/diagram-architecture.svg
@@ -0,0 +1,99 @@
+
diff --git a/docs/slides/diagram-component-stack.svg b/docs/slides/diagram-component-stack.svg
new file mode 100644
index 00000000..5f23f3a0
--- /dev/null
+++ b/docs/slides/diagram-component-stack.svg
@@ -0,0 +1,148 @@
+
diff --git a/docs/slides/diagram-framework-comparison.svg b/docs/slides/diagram-framework-comparison.svg
new file mode 100644
index 00000000..4f3e256c
--- /dev/null
+++ b/docs/slides/diagram-framework-comparison.svg
@@ -0,0 +1,112 @@
+
diff --git a/docs/slides/diagram-rl-loop-bottleneck.svg b/docs/slides/diagram-rl-loop-bottleneck.svg
new file mode 100644
index 00000000..272801dd
--- /dev/null
+++ b/docs/slides/diagram-rl-loop-bottleneck.svg
@@ -0,0 +1,82 @@
+
diff --git a/docs/slides/diagram-transfer-flow.svg b/docs/slides/diagram-transfer-flow.svg
new file mode 100644
index 00000000..aead0cfe
--- /dev/null
+++ b/docs/slides/diagram-transfer-flow.svg
@@ -0,0 +1,88 @@
+
diff --git a/docs/slides/mx-rl-integration-slides.html b/docs/slides/mx-rl-integration-slides.html
new file mode 100644
index 00000000..8b21942e
--- /dev/null
+++ b/docs/slides/mx-rl-integration-slides.html
@@ -0,0 +1,790 @@
+
+
+
+
+
+ModelExpress RL Weight Update Integration
+
+
+
+
+
+
+
+
+
+
+
+ NVIDIA NIXL
+ ModelExpress
+ RL Post-Training
+
+
ModelExpress for RL Weight Updates
+
+ Extending GPU-to-GPU RDMA transfers from inference scaling to the training→inference refit boundary in reinforcement learning post-training.
+
+
+ Target frameworks: NeMo RL · verl · PRIME-RL
+
+
April 2026 — Integration Design Overview
+
+
+
+
+
The Problem
+
The Weight Sync Bottleneck
+
+
+ On-policy RL (GRPO, PPO, DAPO) alternates between rollout generation on inference GPUs and gradient updates on training GPUs. After every training step, updated weights must reach inference before the next rollout — this refit phase stalls both sides.
+
+
+
+
Wall-clock time breakdown (illustrative)
+
+
Rollout
+
Rew
+
Train
+
REFIT
+
+
▲ Up to 30–40% of wall-clock for large models
+
+
+
Current refit latency (70B-class model, multi-node)
+
+
+
Filesystem (PRIME-RL)
+
~20s+
+
+
+
NCCL Broadcast (NeMo RL)
+
~10s
+
+
+
ZMQ IPC (NeMo RL, co-loc)
+
~3-5s
+
+
+
MX RDMA P2P (target)
+
~5s
+
+
+
+
+
+
+
The Solution
+
ModelExpress for Training→Inference Refit
+
+
+ Extend MX from inference-to-inference P2P to the training→inference boundary. Training workers register updated weights with NIXL, publish metadata to the MX Server, and RDMA-WRITE directly into inference GPU memory — bypassing CPU, disk, and collective overheads.
+
+
+
+
+
+
+
Training Workers
+
FSDP2 / Megatron
+
WeightExtractor
+
MxTrainingPublisher
+
NIXL Agent per GPU
+
+
⟶
+
+
MX Server
+
gRPC Metadata Coord
+
Redis / K8s CRD
+
Version Tracking
+
+
⟶
+
+
Inference Workers
+
vLLM / SGLang
+
MxRefitReceiver
+
NIXL Agent per GPU
+
+
+
+ RDMA WRITE — GPU VRAM → GPU VRAM — bypasses CPU & disk
+
+
+
+
+
+
+
~5s
+
Llama-3.3-70B (140 GB)
+
+
+
~15s
+
DeepSeek-V3 MoE (681 GB)
+
+
+
0
+
CPU staging required
+
+
+
Auto
+
IPC → RDMA → TCP fallback
+
+
+
+
+
+
+
Architecture
+
Component Deep-Dive
+
+
+
+
+
+
+
+
+
Integration Map
+
Three Frameworks, One Abstraction
+
+
+ A framework-agnostic WeightSyncBackend decouples the transfer mechanism from each framework's orchestration and sync policy.
+
+
+
+
+
+
Dimension
+
NeMo RL
+
verl
+
PRIME-RL
+
+
+
+
Training Backend
DTensor / Megatron
FSDP / FSDP2 / Megatron
FSDP2 (EP/CP)
+
Inference Backend
vLLM, SGLang
vLLM, SGLang, HF
vLLM
+
Current Sync
ZMQ IPC / HTTP / NCCL
NCCL / CheckpointEngine
Filesystem + HTTP
+
MX Insertion Point
refit_policy_generation()
CheckpointEngine ABC
Orchestrator relay
+
Primary Benefit
RDMA replaces ZMQ/NCCL
Multi-node sans filesystem
Eliminates disk I/O
+
+
+
+
+
+
NeMo RL
+
New branch alongside ZMQ IPC and NCCL in refit function. Bucket-streamed transfer maps to MX publish. Ray actor integration.