diff --git a/docs/MX_RL_OVERVIEW.md b/docs/MX_RL_OVERVIEW.md new file mode 100644 index 00000000..d3f01aa3 --- /dev/null +++ b/docs/MX_RL_OVERVIEW.md @@ -0,0 +1,389 @@ +# ModelExpress for RL Post-Training — Design Overview + +**Last Updated**: April 2026 + +This document explains how ModelExpress (MX) accelerates reinforcement learning (RL) post-training by replacing slow weight transfer mechanisms with GPU-to-GPU RDMA. It covers the general design, the PRIME-RL proof of concept (working), and the verl integration plan. + +--- + +## The Problem: Weight Sync in RL + +RL post-training runs a continuous loop: generate text (inference) → score it → update the model (training) → sync the updated weights back to inference → repeat. The weight sync step is a bottleneck: + +| Method | How | Time (3 GB model) | Limitation | +|--------|-----|-------------------|-----------| +| Filesystem | Serialize to disk → read back | 30-60s | Disk I/O, serialization | +| NCCL broadcast | Collective over network | 2-8s | Requires static groups, `max_async_level=1` | +| **MX + NIXL RDMA** | **GPU reads directly from GPU** | **<1s** | **None for async training** | + +ModelExpress eliminates the serialization-to-disk bottleneck while preserving async training capability. NCCL forces synchronous operation; the filesystem is slow. MX + NIXL gives both speed and flexibility. + +--- + +## How ModelExpress Works for RL + +### Architecture + +```mermaid +sequenceDiagram + participant T as Trainer GPU + participant M as MX Server
(gRPC + Redis) + participant I as Inference GPU + + Note over T: 1. optimizer.step()
weights updated in VRAM + + T->>M: 2. publish_weights()
tensor addrs + NIXL metadata via gRPC + + I->>M: 3. poll_for_source()
"any new weights?" + M-->>I: 4. get_metadata()
addrs + NIXL connection info + + Note over T,I: 5. NIXL RDMA READ — GPU-to-GPU data transfer
(inference reads from trainer's VRAM, CPU not involved) + T-->>I: weight bytes (RDMA) + + Note over I: 6. model.load_weights()
inference applies weights +``` + +**MX Server** stores only metadata — tensor names, GPU memory addresses, NIXL agent connection info, version tracking. It never touches weight data. The bulk transfer is a one-sided RDMA read between GPUs. + +### Client Library + +Two classes in `modelexpress_client/python/modelexpress/`: + +**`MxTrainingPublisher`** (trainer side): +```python +publisher = MxTrainingPublisher("trainer-rank-0", device_id=0, mx_server_url="mx-server:8001") +publisher.initialize(model_name="Qwen/Qwen2.5-1.5B") + +# After optimizer.step(): +publisher.publish_weights(model.state_dict(), step=training_step) +publisher.mark_ready() +``` + +- Registers GPU tensors with NIXL (once, on first call — addresses are stable across steps) +- Publishes tensor metadata + NIXL agent info to MX Server via gRPC +- Marks version as READY so inference can discover it + +**`MxRefitReceiver`** (inference side): +```python +receiver = MxRefitReceiver("inference-rank-0", device_id=0, mx_server_url="mx-server:8001") +receiver.initialize() + +source = receiver.poll_for_source(model_name="Qwen/Qwen2.5-1.5B") +for name, tensor in receiver.receive_weights_scratch(source): + ... # feed into model.load_weights() +``` + +- Queries MX Server for available weight sources +- Allocates scratch GPU buffers matching the source tensor layout +- NIXL RDMA reads weight data directly from the trainer's GPU +- Yields `(name, tensor)` pairs for the inference engine's weight loader + +### The Scratch-Buffer Approach + +RL trainers publish weights in HuggingFace format (339 separate tensors for Qwen2.5-1.5B). Inference engines like vLLM use fused tensors internally (198 parameters — Q/K/V merged into `qkv_proj`, gate/up merged into `gate_up_proj`). Names and shapes don't match. + +Solution: RDMA into temporary scratch buffers matching the trainer's layout, then feed through `model.load_weights()` which handles the name mapping and tensor fusion. 
The RDMA layer stays simple (just move bytes); the inference engine handles semantics. + +--- + +## PRIME-RL POC (Working) + +### What is PRIME-RL? + +An async-first RL framework by PrimeIntellect with three separate processes: +- **Trainer** — FSDP2 training (GPU pod) +- **Orchestrator** — Coordination, scoring (CPU pod) +- **Inference** — vLLM rollout generation (GPU pod) + +No Ray dependency. Raw Kubernetes pods. + +### Architecture + +```mermaid +graph LR + subgraph cluster["GKE GB200 Cluster"] + direction TB + + subgraph control["Control Plane · CPU"] + direction LR + orch["Orchestrator"] + mx["MX Server + Redis"] + end + + subgraph gpus["GPU Pods · 4x GB200 each"] + direction LR + + subgraph tp["Trainer Pod"] + direction TB + fsdp["FSDP2 Training"] + pub["MX Publisher"] + nt(["NIXL Agent"]) + fsdp --> pub --> nt + end + + subgraph ip["Inference Pod"] + direction TB + vllm["vLLM Server"] + recv["MX Receiver"] + ni(["NIXL Agent"]) + ni --> recv --> vllm + end + end + + orch -. "HTTP rollouts" .-> vllm + orch -. "HTTP update_weights" .-> recv + pub -- "gRPC publish" --> mx + recv -- "gRPC discover" --> mx + nt <== "RDMA RoCE · 261-330 Gbps" ==> ni + end + + style cluster fill:#1a1a2e,stroke:#16213e,color:#e0e0e0 + style control fill:#1a1a2e,stroke:#533483,color:#e0e0e0 + style gpus fill:#0f3460,stroke:#533483,color:#e0e0e0 + style tp fill:#162447,stroke:#533483,color:#e0e0e0 + style ip fill:#162447,stroke:#533483,color:#e0e0e0 + style fsdp fill:#533483,stroke:#e94560,color:#fff + style vllm fill:#533483,stroke:#e94560,color:#fff + style orch fill:#533483,stroke:#e94560,color:#fff + style pub fill:#1b5e20,stroke:#4caf50,color:#fff + style recv fill:#1b5e20,stroke:#4caf50,color:#fff + style mx fill:#1b5e20,stroke:#4caf50,color:#fff + style nt fill:#2e7d32,stroke:#66bb6a,color:#fff + style ni fill:#2e7d32,stroke:#66bb6a,color:#fff +``` + +Green = ModelExpress / NIXL components. Purple = existing framework components. 
+ +### Integration Summary + +| Item | Details | +|------|---------| +| **Repo** | `github.com/KavinKrishnan/prime-rl`, branch `kavink/mx-weight-broadcast` | +| **MX client branch** | `github.com/ai-dynamo/modelexpress`, branch `kavink/RL` | +| **New files in PRIME-RL** | `broadcast/modelexpress.py` (trainer), `worker/modelexpress.py` (inference), `Dockerfile.mx-arm64` | +| **Modified files** | 8 files, 93 lines added (configs, routes, factory, client helper) | +| **Cluster** | GKE DGXCloud, GB200 ARM64, 4 GPUs/node, RoCE networking | +| **Model** | Qwen/Qwen2.5-1.5B (3.55 GB BF16) | + +### How It Works + +1. Trainer runs `optimizer.step()`, gathers FSDP2 shards, calls `MxTrainingPublisher.publish_weights()` +2. Orchestrator detects new weights (via filesystem marker), tells inference to update +3. Inference calls `MxRefitReceiver.receive_weights_scratch()` — NIXL RDMA pulls weights from trainer GPU +4. Scratch tensors reshaped using safetensors header shapes, fed through `model.load_weights()` +5. 
vLLM resumes serving with updated model + +### Results + +| Metric | Filesystem | RDMA | Speedup | +|--------|-----------|------|---------| +| Weight update time | ~55s | <1s | **55x** | +| Transfer bandwidth | ~60 MB/s | 261-330 Gbps | ~500x | +| CPU involvement | Full | None | Eliminated | + +### Key Issues Resolved + +| Issue | Root Cause | Fix | +|-------|-----------|-----| +| `NIXL_ERR_NOT_ALLOWED` | Wrong `UCX_TLS` value | Match TRT-LLM: `self,sm,rc,cuda_copy,gdr_copy,tcp` | +| 800 KB metadata blob | Re-registering tensors every step | Register once, reuse cached metadata | +| `REMOTE_DISCONNECT` | UCX 1.18 + missing IMEX channels | UCX 1.20 + DRA `compute-domain-channel` claims | +| Tensor name mismatch | HF names vs vLLM fused names | Scratch-buffer approach + `model.load_weights()` | +| Shape assertion | 1D scratch vs 2D expected | Read shapes from safetensors header | + +### Cluster Config (GCP GB200) + +```yaml +hostNetwork: true +privileged: true +resourceClaims: + - name: compute-domain-channel + resourceClaimTemplateName: kavin-compute-domain-channel +env: + UCX_TLS: "self,sm,rc,cuda_copy,gdr_copy,tcp" + UCX_IB_GID_INDEX: "3" + OMPI_MCA_pml: "ob1" +volumes: + - /dev/infiniband (hostPath) +``` + +UCX 1.20+, IMEX channels via DRA, and `/dev/infiniband` are all required for cross-node GPU RDMA. + +--- + +## verl Integration (In Progress) + +### What is verl? + +A production-grade RL framework by ByteDance that uses Ray for orchestration. Supports FSDP, Megatron, vLLM, SGLang, TRT-LLM. Has a `CheckpointEngine` plugin system for weight transfer. + +### Architecture + +```mermaid +graph LR + subgraph cluster["Ray Cluster · GKE GB200"] + direction TB + + subgraph driver["Driver · CPU"] + task["TaskRunner"] + mgr["CheckpointEngine Manager"] + end + + subgraph trainer_wg["Trainer WorkerGroup · GPU"] + direction LR + tw0["Worker 0
FSDP2 + MX CE"] + tw1["Worker 1
FSDP2 + CE"] + tw2["Worker 2
FSDP2 + CE"] + tw3["Worker 3
FSDP2 + CE"] + end + + subgraph rollout_wg["Rollout Replicas · GPU"] + direction LR + rw0["CE Worker 0"] + rw1["CE Worker 1"] + rw2["CE Worker 2"] + rw3["CE Worker 3"] + vs["vLLM Server"] + end + + subgraph mx_svc["MX Server + Redis · CPU"] + mx["gRPC Metadata Broker"] + end + + task --> mgr + mgr -. "ray.get" .-> tw0 + mgr -. "ray.get" .-> rw0 + tw0 -- "gRPC publish" --> mx + rw0 -- "gRPC discover" --> mx + tw0 <== "NIXL RDMA" ==> rw0 + tw1 <== "NIXL RDMA" ==> rw1 + tw2 <== "NIXL RDMA" ==> rw2 + tw3 <== "NIXL RDMA" ==> rw3 + rw0 -. "CUDA IPC" .-> vs + end + + style cluster fill:#1a1a2e,stroke:#16213e,color:#e0e0e0 + style driver fill:#1a1a2e,stroke:#533483,color:#e0e0e0 + style trainer_wg fill:#0f3460,stroke:#533483,color:#e0e0e0 + style rollout_wg fill:#0f3460,stroke:#533483,color:#e0e0e0 + style mx_svc fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0 + style task fill:#533483,stroke:#e94560,color:#fff + style mgr fill:#1b5e20,stroke:#4caf50,color:#fff + style tw0 fill:#1b5e20,stroke:#4caf50,color:#fff + style tw1 fill:#162447,stroke:#533483,color:#e0e0e0 + style tw2 fill:#162447,stroke:#533483,color:#e0e0e0 + style tw3 fill:#162447,stroke:#533483,color:#e0e0e0 + style rw0 fill:#1b5e20,stroke:#4caf50,color:#fff + style rw1 fill:#2e7d32,stroke:#66bb6a,color:#fff + style rw2 fill:#2e7d32,stroke:#66bb6a,color:#fff + style rw3 fill:#2e7d32,stroke:#66bb6a,color:#fff + style vs fill:#533483,stroke:#e94560,color:#fff + style mx fill:#1b5e20,stroke:#4caf50,color:#fff +``` + +Green = MX checkpoint engine components. Purple = existing verl/Ray components. The `CheckpointEngineManager` coordinates the MX CE workers on both trainer and rollout sides. Each trainer-rollout rank pair transfers via NIXL RDMA. Received weights reach vLLM via CUDA IPC through the existing `ServerAdapter`. 
+ +### Why It's Easier Than PRIME-RL + +verl already has: +- **`CheckpointEngine` ABC** with `send_weights` / `receive_weights` — just implement a new backend +- **Existing NIXL engine** (`nixl_checkpoint_engine.py`) as a reference implementation +- **Bucketed transfers** that preserve tensor names and shapes — no scratch-buffer approach needed +- **`@CheckpointEngineRegistry.register("mx")`** — one decorator to plug in + +### Integration Summary + +| Item | Details | +|------|---------| +| **Repo** | `github.com/KavinKrishnan/verl`, branch `kavink/mx-checkpoint-engine` | +| **New file** | `verl/checkpoint_engine/mx_checkpoint_engine.py` (461 lines) | +| **Modified files** | 2 files, 8 lines (imports + config comment) | +| **Total** | 469 lines added (461 in the new file + 8 across the 2 modified files) | + +### `MxCheckpointEngine` Design + +Registered as `@CheckpointEngineRegistry.register("mx")`. Implements the `CheckpointEngine` ABC: + +- **`prepare()`** — Allocate send/recv GPU buckets, register with NIXL, create MX client +- **`build_topology()`** — Trainer rank 0 is the source, all rollout ranks connect via MX Server +- **`init_process_group()`** — Trainer adds rollout agents; rollouts add trainer agent +- **`send_weights()`** — Pack tensors into GPU buckets, send bucket metadata via ZMQ, make available for RDMA read +- **`receive_weights()`** — Receive metadata via ZMQ, NIXL RDMA read from trainer's bucket, yield `(name, tensor)` pairs +- **`finalize()`** — Cleanup connections, deregister memory + +Uses a **star topology** (trainer → all rollouts) instead of the NIXL engine's ring. The MX Server enables future pipeline replication where rollouts become sources. 
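The `send_weights()` / `receive_weights()` pair above amounts to packing named tensors into fixed-size buckets and shipping a small metadata record with each bucket, so names and shapes survive the transfer. The sketch below illustrates only that bookkeeping in plain Python, with `bytes` standing in for GPU memory. `BucketEntry`, `pack_buckets`, and `unpack_bucket` are hypothetical names chosen for illustration; they are not the verl or ModelExpress API, and the real engine moves the bucket via NIXL RDMA rather than returning it.

```python
from dataclasses import dataclass
from typing import Iterable, Iterator


@dataclass(frozen=True)
class BucketEntry:
    """Metadata for one tensor inside a bucket (hypothetical layout)."""
    name: str
    offset: int              # byte offset of the tensor inside the bucket
    nbytes: int              # payload length in bytes
    shape: tuple[int, ...]   # original tensor shape, carried so the receiver can reshape


def pack_buckets(
    tensors: Iterable[tuple[str, bytes, tuple[int, ...]]],
    bucket_size: int,
) -> list[tuple[bytes, list[BucketEntry]]]:
    """Greedily pack (name, payload, shape) triples into buckets of at most bucket_size bytes."""
    buckets: list[tuple[bytes, list[BucketEntry]]] = []
    buf = bytearray()
    entries: list[BucketEntry] = []
    for name, payload, shape in tensors:
        if len(payload) > bucket_size:
            raise ValueError(f"{name} ({len(payload)} B) does not fit in one bucket")
        if len(buf) + len(payload) > bucket_size:
            # Current bucket is full: flush it and start a new one.
            buckets.append((bytes(buf), entries))
            buf, entries = bytearray(), []
        entries.append(BucketEntry(name, len(buf), len(payload), shape))
        buf += payload
    if entries:
        buckets.append((bytes(buf), entries))
    return buckets


def unpack_bucket(
    buf: bytes, entries: list[BucketEntry]
) -> Iterator[tuple[str, bytes, tuple[int, ...]]]:
    """Receiver side: slice each tensor back out using the bucket metadata."""
    for e in entries:
        yield e.name, buf[e.offset:e.offset + e.nbytes], e.shape
```

Because each entry carries its own name and shape, the receiver can yield `(name, tensor)` pairs directly. This is why the verl path does not need the scratch-buffer and safetensors-header workaround used in the PRIME-RL integration.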
+ +### Config + +```yaml +actor_rollout_ref: + rollout: + checkpoint_engine: + backend: "mx" + engine_kwargs: + mx_server_url: "modelexpress-server:8001" + model_name: "Qwen/Qwen2.5-1.5B" +``` + +--- + +## Comparison: PRIME-RL vs verl Integration + +| Aspect | PRIME-RL | verl | +|--------|---------|------| +| Plugin system | Custom broadcast/worker extensions | `CheckpointEngine` registry | +| Existing NIXL support | None (built from scratch) | Full NIXL engine as reference | +| Lines of code | ~350 new + 93 modified | ~460 new + 8 modified | +| Ray dependency | None (raw K8s pods) | Ray actors, placement groups | +| Weight format | HF names → scratch buffers → `load_weights()` | Bucketed with shapes preserved | +| Tensor shape issue | Required safetensors header reading | Not an issue (bucket metadata carries shapes) | +| Colocated mode | N/A (always disaggregated) | `naive` engine handles colocated; MX for disaggregated | + +--- + +## Future: Pipeline Replication + +Current design uses a star topology (trainer → all rollouts). At scale, the trainer's NIC becomes a bottleneck. [TensorHub](https://arxiv.org/abs/2604.09107v1) (ByteDance, April 2026) demonstrates **pipeline replication**: after a rollout receives weights, it publishes itself as a source. New rollouts pull from the nearest/least-loaded replica, creating a bandwidth-amplifying DAG. + +MX Server already supports multiple sources per model — implementing pipeline replication is a client-side change: +1. After RDMA receive, rollout calls `publish_weights()` on MX Server +2. New rollouts call `poll_for_source()` which returns the nearest available replica +3. MX Server load-balances across all replicas + +This is prioritized as P1 in our roadmap. 
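To see why pipeline replication relieves the trainer's NIC, the following back-of-the-envelope model counts transfer rounds under a star topology and under replication. It is a sketch under simplifying assumptions that are ours, not measurements: every transfer takes one round, each source serves exactly one receiver per round, and there is no network contention. `rounds_star` and `rounds_pipeline` are hypothetical helper names.

```python
def rounds_star(num_rollouts: int) -> int:
    """Star topology: the trainer is the only source, so it serves one rollout per round."""
    return num_rollouts


def rounds_pipeline(num_rollouts: int) -> int:
    """Pipeline replication: every rollout that has received weights becomes a source."""
    sources = 1   # the trainer
    served = 0
    rounds = 0
    while served < num_rollouts:
        # Each current source feeds one new rollout this round.
        newly_served = min(sources, num_rollouts - served)
        served += newly_served
        sources += newly_served   # receivers republish themselves as sources
        rounds += 1
    return rounds


if __name__ == "__main__":
    for n in (4, 16, 64):
        print(f"{n} rollouts: star={rounds_star(n)} rounds, pipeline={rounds_pipeline(n)} rounds")
```

Under these assumptions the star cost grows linearly with the number of rollouts, while the replicated cost grows logarithmically because the pool of sources roughly doubles each round. Real gains depend on NIC placement and on how the server picks replicas.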
+ +--- + +## Repository Map + +### ModelExpress client (`kavink/RL` branch) + +```text +modelexpress_client/python/modelexpress/ +├── training_publisher.py # MxTrainingPublisher — trainer-side publish +├── refit_receiver.py # MxRefitReceiver — inference-side RDMA receive +├── nixl_transfer.py # NixlTransferManager — NIXL agent lifecycle +├── client.py # MxClient — gRPC client to MX Server +└── __init__.py # Exports MxTrainingPublisher, MxRefitReceiver +``` + +### PRIME-RL integration (`kavink/mx-weight-broadcast` branch) + +```text +src/prime_rl/ +├── trainer/rl/broadcast/modelexpress.py # ModelExpressWeightBroadcast +├── inference/vllm/worker/modelexpress.py # MxWeightUpdateWorker +├── inference/vllm/server.py # /init_mx_broadcaster route (+13 lines) +├── orchestrator/orchestrator.py # elif "modelexpress" branch (+7 lines) +├── utils/client.py # init_mx_broadcast() (+34 lines) +└── configs/ # MxWeightBroadcastConfig (+32 lines) +``` + +### verl integration (`kavink/mx-checkpoint-engine` branch) + +```text +verl/ +├── checkpoint_engine/mx_checkpoint_engine.py # MxCheckpointEngine +├── checkpoint_engine/__init__.py # Optional import (+7 lines) +└── workers/config/rollout.py # "mx" in backend comment (+1 line) +``` diff --git a/docs/RL/NEMORL_MX_OVERVIEW.md b/docs/RL/NEMORL_MX_OVERVIEW.md new file mode 100644 index 00000000..b3fe90cc --- /dev/null +++ b/docs/RL/NEMORL_MX_OVERVIEW.md @@ -0,0 +1,769 @@ +# ModelExpress × NeMo-RL — Design + Validation Overview (v2) + +**Last Updated**: May 8, 2026 +**Status**: **End-to-end NIXL RDMA refit working on real GB200**, prototyped on `kavink/nemo_rl_moe` (MX) + `kavink/mx_integration` ([NVIDIA-NeMo/RL](https://github.com/NVIDIA-NeMo/RL)). 4 ranks × 2 cycles × toy tensors verified byte-correct (sentinels match). 4 ranks × 1.6 GB Qwen3-30B-A3B-shaped tensors land in 11–16 ms each. 
**4 ranks × 2 cycles × real `torch.distributed.tensor.DTensor` (Shard(0) FSDP placement) verified byte-correct AND verified that the receiver's reconstructed shape registry reports the correct un-sharded `global_shape` plus per-rank `local_shard_range` for every tensor.** 15/15 unit tests passing. arm64 NemoRL overlay image (`nvcr.io/nvidian/dynamo-dev/nemo-rl:kavink-v2`) built and smoke-tested but not yet pushed to a registry, so the actual Ray-orchestrated NemoRL training loop on Qwen3 hasn't been driven yet — that's the next milestone, gated only on image push + a K8s manifest. + +This document is the technical companion to the upstream PR. It covers: + +1. Why we built this (lessons from PrimeRL #2389 + the TensorHub paper + Composer 2 router replay). +2. The v2 design — 4 pillars: rank-to-rank publish, tree scale-out, MoE expert filter, explicit shape registry. +3. The Python API surface a NemoRL caller sees. +4. The full file inventory across the two branches. +5. Where the running MX server has gaps and how we work around them today. +6. What was tested vs what is still on paper. +7. End-to-end deployment recipe for the next session. + +--- + +## 1. Motivation + +NeMo-RL today has two weight-sync paths — **`update_weights_via_ipc_zmq`** (CUDA IPC handles over a ZMQ socket; only valid in colocated/hybrid mode) and **`update_weights_from_collective`** (NCCL `broadcast` from rank 0 in non-colocated mode). The NCCL path has three blockers for the workloads we care about: + +| Problem | Where it bites | +|---|---| +| **`tensor.full_tensor()` allgather on every refit** | `dtensor_policy_worker.py:1822-1834` (`broadcast_weights_for_collective`). On a 30B-MoE with FSDP=4 this is ~120 GB through rank-0's NIC per refit. | +| **Static NCCL group** | NCCL barrier locks the trainer + all rollout replicas into a fixed world. Spot/elastic rollout, mid-run rebalancing, cross-DC — all blocked. 
| +| **No MoE awareness** | Every rank receives every expert weight, even if its EP shard only needs 1/8th of them. Composer 2 reports this as the dominant refit cost on Kimi K2.5 (1.04T / 32B active). | + +PrimeRL PR [#2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389) is the closest framework analog. We live-debugged it on GB200 in early May and learned two things the hard way: + +1. **Cross-subnet full-mesh in `TransportPlan` ≠ routable.** GCP GB200's four `mlx5_N` NICs each sit on their own L3 subnet (`rdma-0..rdma-3`); the full-mesh `add_remote_agent` loop hits `NIXL_ERR_REMOTE_DISCONNECT` whenever (trainer rank N → inference rank M ≠ N). For the 1-to-1 dp-only layout that NeMo-RL also uses, **same-rank-only writes** are both topologically correct and 3× cheaper in NIXL connection count. + +2. **vLLM workers don't unpublish-on-death and don't heartbeat.** Each orchestrator restart leaves stale `READY` rows in MX Redis; subsequent `add_remote_agent` calls choke on the dead rows with `NIXL_ERR_NOT_ALLOWED`. The fix is `(worker_rank, max(updated_at))` dedup at read time, plus a real heartbeat on the publisher side. + +The TensorHub paper ([arXiv 2604.09107v1](https://arxiv.org/pdf/2604.09107v1)) gave us the production-quality framing — **Reference-Oriented Storage**, mutability contract, retention protocol, and crucially **pipeline replication** (a receiver becomes a source for the next receiver, building an expanding DAG that scales bandwidth with the number of active clients). Cursor's Composer 2 technical report adds **router replay + per-expert delta compression** as the MoE-specific shape we need. + +**v2 = NemoRL's existing IPC/NCCL-style API + every learning above.** + +--- + +## 2. 
Comparison to existing NeMo-RL paths + +| Property | `update_weights_via_ipc_zmq` | `update_weights_from_collective` (NCCL) | **`update_weights_via_mx` (this PR)** | +|---|---|---|---| +| Cross-node | ❌ (colocated only) | ✅ | ✅ | +| Full-mesh allgather | ❌ | ✅ — `tensor.full_tensor()` on rank 0 | **❌ — `tensor.to_local()` on every rank** | +| Trainer NIC bottleneck | n/a | yes — single rank-0 funnel | **no — N parallel rank-N → rank-N pairs** | +| MoE expert filtering | none | none | **first-class — owned/needed expert IDs in metadata** | +| Tree fan-out for cold-start replicas | none | none | **TensorHub pipeline replication** | +| Cross-DC | ❌ | ❌ | designed (P3, not yet wired) | +| Elastic rollouts | ❌ | ❌ | ✅ — NIXL connections are dynamic, no static world | +| Heartbeat / liveness | n/a | n/a | ✅ — `HeartbeatThread` per worker | +| Versioning / freshness | implicit via NCCL barrier order | implicit | **explicit via `version: int` on every publish** | +| Mutability contract | none | none | **`set_status(STALE)` drains in-flight readers (TensorHub-style)** | +| Backward-compat default | n/a | yes | yes — opt-in via `cluster.weight_sync.method: "mx"` | + +--- + +## 3. Architecture + +``` + ┌────────────────────────────────────────────┐ + │ MX Server (Rust + Redis) │ + │ │ + │ publish_metadata / list_sources / │ + │ get_metadata / update_status │ + │ │ + │ storage layout (Redis HASH): │ + │ mx:source:{16-hex} │ + │ __attributes__ → SourceAttributesJson │ + │ (incl. 
extra_params) │ + │ → worker_rank │ + │ mx:source:{16-hex}:{worker_uuid} │ + │ → WorkerRecordJson │ + └─────┬──────────────────────────┬───────────┘ + │ gRPC │ gRPC + publish + register query / list + │ │ + ┌────────────────────┴───┐ ┌────────────┴──────────┐ + │ Trainer ranks │ │ Inference ranks │ + │ (FSDP2 / DTensor) │ │ (vLLM, EP-sharded) │ + │ │ │ │ + │ rank N → publish │ │ rank N → discover │ + │ its local DTensor │ │ same-rank source │ + │ shard, with │ RDMA │ via picker (filter │ + │ per-tensor │ ◄────── │ by worker_rank, │ + │ placement info │ WRITE │ dedup by latest │ + │ │ │ updated_at) │ + │ Each rank registers │ │ │ + │ buffers with NIXL │ │ After receive: each │ + │ ONCE (addresses are │ │ rank registers itself │ + │ stable across steps) │ │ as a NEW source for │ + │ │ │ subsequent receivers │ + │ HeartbeatThread keeps │ │ (TensorHub pipeline- │ + │ updated_at fresh │ │ replication trick) │ + └────────────────────────┘ └────────────────────────┘ +``` + +--- + +## 4. The four design pillars + +### 4.1 Pillar 1 — Rank-to-rank publish (no allgather) + +**Trainer side.** Each FSDP/EP rank publishes only its **local DTensor shard**, never `tensor.full_tensor()`. The placement (which axis is sharded, which range of indices this rank holds) travels in the per-tensor metadata. 
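To make "which range of indices this rank holds" concrete, here is a minimal sketch of the placement bookkeeping for a tensor sharded along one axis. It assumes `torch.chunk`-style ceil-division splitting, which is how we understand `Shard(dim)` placements to divide an uneven dimension. `local_shard_range` and `describe_shard` are hypothetical helpers and the returned dict layout is illustrative only; it borrows the `global_shape` and `local_shard_range` terms used in this document but is not the `TensorDescriptorV2` wire format.

```python
def local_shard_range(dim_size: int, world_size: int, rank: int) -> tuple[int, int]:
    """Half-open [start, end) index range owned by `rank` along the sharded axis.

    Assumes chunk-style splitting: every rank gets ceil(dim_size / world_size)
    rows except trailing ranks, which may get fewer rows or none at all.
    """
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} outside world of size {world_size}")
    chunk = -(-dim_size // world_size)   # ceil division
    start = min(rank * chunk, dim_size)
    end = min(start + chunk, dim_size)
    return start, end


def describe_shard(
    global_shape: tuple[int, ...], shard_axis: int, world_size: int, rank: int
) -> dict:
    """Build an illustrative per-tensor placement record for one rank."""
    start, end = local_shard_range(global_shape[shard_axis], world_size, rank)
    local_shape = list(global_shape)
    local_shape[shard_axis] = end - start
    return {
        "global_shape": tuple(global_shape),   # un-sharded shape the receiver reconstructs
        "shard_axis": shard_axis,
        "local_shard_range": (start, end),     # slice of the global tensor this rank holds
        "local_shape": tuple(local_shape),     # shape of the buffer registered with NIXL
    }
```

A receiver that collects one such record per rank can check that the ranges tile the sharded axis without gaps or overlaps before trusting the reconstructed global shape.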
+ +```python +# nemo_rl/models/policy/workers/dtensor_policy_worker.py +@torch.no_grad() +@wrap_with_nvtx_name("dtensor_policy_worker/stream_weights_via_mx") +def stream_weights_via_mx(self, *, version: int, mx_config: Any) -> None: + if not hasattr(self, "_mx_publisher") or self._mx_publisher is None: + self._mx_publisher = build_v2_publisher( + rank=self.rank, + device_id=self.local_device_index, + fsdp_world_size=self.world_size, + tp_world_size=self.tp_size or 1, + pp_world_size=self.pp_size or 1, + ep_world_size=self.ep_size or 1, + mx_config=mx_config, + ) + self._mx_publisher.initialize(model_name=self.model_name, dtype=str(self.dtype).removeprefix("torch.")) + self._mx_expert_layout = detect_moe_expert_layout( + self.model, ep_world_size=self.ep_size or 1, rank=self.rank, + ) if mx_config.moe_expert_filter else {} + + self._mx_publisher._registry.clear() + self._mx_publisher._registered_tensors.clear() + for name, tensor in self.model.state_dict().items(): + local = tensor.to_local() if isinstance(tensor, DTensor) else tensor # ← key: NO allgather + local = local.to(self.dtype, non_blocking=True).contiguous() + expert_info = self._mx_expert_layout.get(name) + self._mx_publisher.add_tensor( + name=name, + tensor=local, + is_expert=expert_info is not None, + expert_axis=expert_info[0] if expert_info else 0, + owned_expert_ids=expert_info[1] if expert_info else set(), + ) + # Override the descriptor's global_shape from the DTensor view so the + # receiver knows the un-sharded shape. The NIXL-registered buffer is + # still the local shard. 
+ if isinstance(tensor, DTensor): + self._mx_publisher._registry[-1].global_shape = tuple(int(s) for s in tensor.shape) + + self._mx_publisher.publish(version=int(version)) + self._mx_publisher.mark_ready() # ← starts HeartbeatThread +``` + +**Cost vs allgather pattern (Qwen3-30B-A3B FSDP=4)**: +- v1 / NCCL: 4 ranks each allgather → 4× full model materialized in VRAM at peak; rank 0's NIC ships full model to every inference rank → ~120 GB through one NIC. +- v2: each rank holds only its 1/4 shard; each rank's NIC ships its 1/4 → 4 NICs in parallel → 4× the aggregate bandwidth. + +### 4.2 Pillar 2 — Tree scale-out (TensorHub pipeline replication) + +After an inference rank finishes receiving its slice, it **becomes a source** by re-registering its already-NIXL-registered receive buffers and publishing as `inference_replica`. Subsequent same-rank receivers can pull from it instead of contending on the trainer's NIC. + +```python +# modelexpress/nemo_rl_v2.py — MxV2RefitReceiver +def publish_self_as_source(self, *, version: int, model_name: str) -> str | None: + identity = p2p_pb2.SourceIdentity( + model_name=model_name, mx_source_type=p2p_pb2.MX_SOURCE_TYPE_WEIGHTS, + backend_framework=p2p_pb2.BACKEND_FRAMEWORK_UNKNOWN, + dtype="bfloat16", + extra_parameters={ + "role": ROLE_INFERENCE_REPLICA, + "mx_v2": "1", + "worker_rank": str(self._worker_rank), + "training_step": str(int(version)), + "training_framework": "nemo_rl", + }, + ) + worker_meta = p2p_pb2.WorkerMetadata( + worker_rank=self._worker_rank, + nixl_metadata=nixl.nixl_metadata, + tensors=[p2p_pb2.TensorDescriptor(name=d.name, addr=d.addr, size=d.size, + device_id=d.device_id, dtype=d.dtype) + for d in nixl.tensor_descriptors], + status=p2p_pb2.SOURCE_STATUS_READY, + agent_name=self._receiver._agent_name, + ) + return client.publish_metadata(identity=identity, worker=worker_meta, + worker_id=self._receiver._worker_id) +``` + +The picker prefers `trainer` over `inference_replica` when both are visible (the 
trainer is always authoritative), then breaks ties on `max(updated_at)`. This means: +- First receiver → pulls from trainer. +- Second receiver (slow / cold-start / restart) → may pull from the first receiver if trainer is busy. (Today: picker just picks freshest trainer; the load-balancing improvement is a follow-up.) + +### 4.3 Pillar 3 — MoE expert filtering + +Each tensor descriptor carries `is_expert: bool`, `expert_axis: int`, and the publisher's `owned_expert_ids: set[int]`. An EP-sharded inference rank's `pick_best_source` accepts an optional `needed_experts_per_layer` filter and rejects candidates that don't cover all needed experts. + +```python +# Receiver side +chosen = receiver.pick_best_source( + candidates, + needed_experts_per_layer={5: {72, 73, 74, 75, 76, 77}}, # what THIS rank needs +) +``` + +Composer 2 extends this with a **dirty-experts bitmap** ("only experts with non-zero gradient since last refit"). The MX-side surface for that is designed (`set_dirty_experts`, `get_dirty_experts` in the design doc) but not yet implemented; v0 refits all owned experts. + +### 4.4 Pillar 4 — Explicit shape registry / mutability contract + +Every published tensor carries a `TensorDescriptorV2` (placement kind, shard axis, local shard range, expert axis, owned expert IDs). Receivers consult these to know exactly what to expect and where each shard fits in the global tensor. + +The mutability contract follows TensorHub §3.2: +- The trainer **MUST NOT** mutate `tensor` between `publish_metadata(version=v)` and `set_status(STALE)` for that version. +- `set_status(STALE)` is intended to block until in-flight RDMA reads complete. (Today it's a no-op on the server side; the proper drain semantics are a follow-up — design only.) +- Inference workers commit to the same contract for any version they hold. + +--- + +## 5. 
Public Python API + +Three new symbols on the MX side: + +```python +from modelexpress import ( + MxV2TrainingPublisher, # trainer-side wrapper + MxV2RefitReceiver, # inference-side wrapper + TrainerWorldLayout, # (fsdp, tp, pp, ep) descriptor +) +from modelexpress.shape_descriptors import ( + TensorDescriptorV2, + describe_tensor, # DTensor → wire format + even_expert_owner_map, +) +``` + +One new module on the NemoRL side: + +```python +from nemo_rl.distributed.mx_helpers import ( + MxConfig, # parsed from cfg.cluster.weight_sync + build_v2_publisher, # convenience constructor + NIC pin + build_v2_receiver, + pin_local_nic, + collect_named_local_shards, + detect_moe_expert_layout, +) +``` + +Two new abstract methods on the existing interfaces: + +```python +# nemo_rl/models/policy/interfaces.py +class ColocatablePolicyInterface(PolicyInterface): + def stream_weights_via_mx(self, *, version: int, mx_config: Any) -> list[ray.ObjectRef]: + raise NotImplementedError("...") + +# nemo_rl/models/generation/interfaces.py +class GenerationInterface(ABC): + def update_weights_via_mx(self, *, version: int, mx_config: Any) -> list[ray.ObjectRef]: + raise NotImplementedError("...") +``` + +One new branch in `algorithms/grpo.py`: + +```python +# nemo_rl/algorithms/grpo.py::refit_policy_generation +elif weight_sync_method == "mx": + if mx_config is None or not getattr(mx_config, "enabled", False): + raise RuntimeError( + "weight_sync_method='mx' requires an enabled MxConfig " + "(cfg.cluster.weight_sync.method='mx', .enabled=True)" + ) + version = int(refit_version) if refit_version is not None else 0 + futures_train = policy.stream_weights_via_mx(version=version, mx_config=mx_config) + futures_inference = policy_generation.update_weights_via_mx(version=version, mx_config=mx_config) + ray.get(futures_train) + results = ray.get(futures_inference) + update_success = all(result for result in results if result is not None) +``` + +`MxConfig` knobs (parsed from 
`cfg.cluster.weight_sync`): + +```python +from dataclasses import dataclass, field + +@dataclass +class MxConfig: + enabled: bool = False # master switch + mx_server_url: str = "modelexpress-server:8001" + timeout_seconds: float = 300.0 + same_rank_only: bool = True # ← required on multi-subnet RDMA fabrics (GB200 / EFA) + tree_scale_out: bool = True # ← receivers republish as inference_replica + moe_expert_filter: bool = True # ← only request owned experts + register_self_buffers: list[str] = field(default_factory=list) # a bare [] default raises ValueError in a dataclass + nic_pin: str = "auto" # "auto" | "off" | "mlx5_X" + retain_latest_k: int = 1 # TensorHub-style retention (designed; not enforced server-side yet) +``` + +Worked example for Qwen3-30B-A3B: + +```yaml +cluster: + weight_sync: + method: "mx" + enabled: true + mx_server_url: "modelexpress-server.kavin.svc.cluster.local:8001" + timeout_seconds: 300.0 + same_rank_only: true # GB200/EFA — keep ON + tree_scale_out: true + moe_expert_filter: true + nic_pin: "auto" + retain_latest_k: 1 +``` + +--- + +## 6. File inventory across the two branches + +### `kavink/nemo_rl_moe` (off `kavink/RL` in `NVIDIA-Model-Optimizer/modelexpress`) + +``` +modelexpress_common/proto/p2p.proto M +7 added SourceIdentity identity = 5 to GetMetadataResponse +modelexpress_client/python/modelexpress/__init__.py M +8 re-export MxV2TrainingPublisher / MxV2RefitReceiver / TrainerWorldLayout +modelexpress_client/python/modelexpress/p2p_pb2.py M ±40 regenerated +modelexpress_client/python/modelexpress/p2p_pb2_grpc.py M ±4 regenerated +modelexpress_client/python/modelexpress/shape_descriptors.py A +277 TensorDescriptorV2, describe_tensor, even_expert_owner_map, codecs +modelexpress_client/python/modelexpress/nemo_rl_v2.py A +752 MxV2TrainingPublisher, MxV2RefitReceiver, TrainerWorldLayout, V2SourceCandidate +modelexpress_client/python/scripts/v2_moe_e2e_demo.py A +253 standalone GB200 cluster demo +modelexpress_client/python/tests/test_v2_shape_registry.py A +187 8 unit tests +modelexpress_client/python/tests/test_v2_source_picker.py A +469 7 unit tests 
(mocked NIXL/gRPC) +modelexpress_server/src/metadata_backend.rs M +9 ModelMetadataRecord.identity field +modelexpress_server/src/metadata_backend/redis.rs M +36 SourceAttributesJson.extra_parameters + to_source_identity() +modelexpress_server/src/metadata_backend/kubernetes.rs M +5 identity: None placeholder (CRD schema bump deferred) +modelexpress_server/src/p2p_service.rs M +10 populate GetMetadataResponse.identity +modelexpress_server/src/state.rs M +1 identity: None in test fixtures +``` + +Two commits: +- `97c0e78` — client Python (publisher / receiver / shape descriptors / tests / demo / proto regen) +- `0bce4f0` — server Rust (round-trip the SourceIdentity through GetMetadata) + +### `kavink/mx_integration` (off `main` in `NVIDIA-NeMo/RL`) + +``` +nemo_rl/algorithms/grpo.py M +49 mx branch in refit_policy_generation +nemo_rl/distributed/mx_helpers.py A +250 MxConfig, build_v2_*, pin_local_nic, collect_named_local_shards, detect_moe_expert_layout +nemo_rl/models/generation/interfaces.py M +23 abstract update_weights_via_mx +nemo_rl/models/generation/vllm/vllm_backend.py M +127 VllmInternalWorkerExtension.update_weights_via_mx (NIXL receive + _load_weights) +nemo_rl/models/generation/vllm/vllm_generation.py M +20 Ray driver fan-out +nemo_rl/models/generation/vllm/vllm_worker.py M +27 Ray actor entry → collective_rpc("update_weights_via_mx", ...) +nemo_rl/models/policy/interfaces.py M +29 abstract stream_weights_via_mx (default-NotImplementedError, opt-in) +nemo_rl/models/policy/lm_policy.py M +14 Policy.stream_weights_via_mx (worker_group fan-out) +nemo_rl/models/policy/workers/dtensor_policy_worker.py M +111 DTensorPolicyWorker.stream_weights_via_mx (uses tensor.to_local()) +docker/v2_overlay/Dockerfile A +80 thin overlay over nvcr.io/nvidia/nemo-rl:v0.6.0 +``` + +One commit: `d58dca07`. + +Both branches are committed but **not pushed** (you'll do that with the appropriate auth). + +--- + +## 7. 
Three-tier metadata transport (server-side workaround) + +The **running** MX server in our `kavin` namespace (`nvcr.io/nvidian/dynamo-dev/modelexpress-server:latest`, started May 6) drops most string fields when echoing `WorkerMetadata` back via `GetMetadata`. Confirmed by direct gRPC introspection on `prime-rl-nixl-mx-trainer-0`: + +``` +> client.list_sources(...) instances → ok (model_name, worker_rank present on SourceInstanceRef) +> client.get_metadata(instance.mx_source_id, instance.worker_id): + found=True + worker.tensors → preserved ✅ + worker.status → preserved ✅ + worker.worker_rank → preserved ✅ + worker.nixl_metadata → preserved ✅ (the bytes blob) + worker.updated_at → preserved ✅ + worker.agent_name → '' ❌ (publisher set it, server dropped it) + worker.metadata_endpoint → '' ❌ + worker.worker_grpc_endpoint → '' ❌ + identity → NOT PRESENT ❌ (the proto field didn't exist in the running build) + identity.extra_parameters → n/a (would be empty even if present, since identity wasn't returned at all) +``` + +Because v2 metadata (`mx_v2`, `role`, `worker_rank`, `training_step`, `shape_registry`) was originally designed to live in `SourceIdentity.extra_parameters`, it can't reach the receiver via the current server. 
 + +The v2 receiver tries **three transports in order**, falling back to the next when the previous returns empty: + +```python +# modelexpress/nemo_rl_v2.py::MxV2RefitReceiver.discover_v2_sources + +# 1) SourceIdentity.extra_parameters via meta.identity +identity = getattr(meta, "identity", None) +extra = (dict(identity.extra_parameters) if identity is not None and identity.extra_parameters else {}) + +# 2) Synthetic TensorDescriptor sidecar (the path the prototype actually uses today) +if not extra: + for td in meta.worker.tensors: + if td.name == _V2_SIDECAR_NAME and td.dtype: # _V2_SIDECAR_NAME = "__mx_v2_meta__" + try: + sidecar = json.loads(td.dtype) + if isinstance(sidecar, dict): + for k, v in sidecar.items(): + extra[k] = str(v) + except (json.JSONDecodeError, TypeError): + pass + break + +# 3) WorkerMetadata.agent_name string-encoded marker (legacy fallback) +if not extra: + agent_name = getattr(meta.worker, "agent_name", "") or "" + if agent_name.startswith("mx_v2|"): + # ... parse "mx_v2|<role>|rank=N|version=K|orig=..." 
+``` + +Symmetrically, the publisher writes v2 metadata into all three locations: + +```python +# modelexpress/nemo_rl_v2.py::MxV2TrainingPublisher.publish + +# Path 1: extra_parameters (forward-compat, used once Rust server populates GetMetadataResponse.identity) +def _build_identity_with_v2(step): + ident = original_build_identity(step) + ident.extra_parameters["role"] = ROLE_TRAINER + ident.extra_parameters["mx_v2"] = "1" + ident.extra_parameters["worker_rank"] = str(self._worker_rank) + ident.extra_parameters["shape_registry"] = registry_blob + ident.extra_parameters["world_layout"] = self._world_layout.encode() + return ident +self._publisher._build_identity = _build_identity_with_v2 + +# Path 2: synthetic TensorDescriptor sidecar (today's transport) +sidecar_payload = json.dumps({ + "mx_v2": "1", "role": ROLE_TRAINER, + "worker_rank": int(self._worker_rank), + "training_step": int(version), + "world_layout": self._world_layout.encode(), + "framework": "nemo_rl", +}) +def _build_tensor_protos_with_sidecar(descriptors): + protos = original_build_tensor_protos(descriptors) + protos.append(p2p_pb2.TensorDescriptor( + name="__mx_v2_meta__", addr=0, size=0, device_id=0, dtype=sidecar_payload, + )) + return protos +self._publisher._build_tensor_protos = _build_tensor_protos_with_sidecar + +# Path 3: agent_name encoding +self._publisher._agent_name = ( + f"mx_v2|{ROLE_TRAINER}|rank={self._worker_rank}|" + f"version={int(version)}|orig={original_agent_name}" +) +``` + +**The Rust server fix is committed in commit `0bce4f0` on `kavink/nemo_rl_moe`** but requires a server image rebuild + redeploy to land. Specifically: + +| Change | File | What | +|---|---|---| +| Proto: add `SourceIdentity identity = 5;` to `GetMetadataResponse` | `modelexpress_common/proto/p2p.proto` | Field already added; clients regenerated. Backward-compat (older clients ignore). 
| +| Storage: `SourceAttributesJson` gains `extra_parameters: HashMap` | `modelexpress_server/src/metadata_backend/redis.rs:39-72` | `#[serde(default)]` so old records read clean. | +| Storage: `SourceAttributesJson::to_source_identity()` reconstructs full `SourceIdentity` | `modelexpress_server/src/metadata_backend/redis.rs:88-104` | Used by `get_metadata`. | +| Service: populate `GetMetadataResponse.identity` from `record.identity` | `modelexpress_server/src/p2p_service.rs:170-230` | One-line wiring. | +| K8s CRD backend: `identity: None` for now (CRD schema bump separate) | `modelexpress_server/src/metadata_backend/kubernetes.rs:439-450` | v2 clients fall back to sidecar. | + +Once the new server image lands, transports collapse to Path 1 (cleanest); the sidecar TensorDescriptor stays in the wire as a no-op fallback. + +--- + +## 8. What was tested + +### 8.1 Unit tests (no GPU, no NIXL) + +```bash +cd ~/Work/Github/MX0/modelexpress # or upstream equivalent after push +python3 -m pytest modelexpress_client/python/tests/test_v2_shape_registry.py \ + modelexpress_client/python/tests/test_v2_source_picker.py -v +``` + +Result: **15/15 PASS**: + +``` +test_v2_shape_registry.py::test_replicate_descriptor_round_trip PASSED +test_v2_shape_registry.py::test_sharded_dtensor_local_range PASSED +test_v2_shape_registry.py::test_moe_expert_descriptor_in_registry PASSED +test_v2_shape_registry.py::test_expert_owner_map_uniform PASSED +test_v2_shape_registry.py::test_expert_owner_map_rejects_uneven PASSED +test_v2_shape_registry.py::test_expert_set_codec_round_trip PASSED +test_v2_shape_registry.py::test_decode_expert_set_handles_empty_and_whitespace PASSED +test_v2_shape_registry.py::test_registry_full_round_trip_multitensor PASSED +test_v2_source_picker.py::test_same_rank_filter_dedup_freshest PASSED +test_v2_source_picker.py::test_min_version_filter PASSED +test_v2_source_picker.py::test_non_v2_sources_ignored PASSED 
+test_v2_source_picker.py::test_pick_best_with_expert_filter PASSED +test_v2_source_picker.py::test_pick_best_falls_back_to_trainer PASSED +test_v2_source_picker.py::test_world_layout_round_trip PASSED +test_v2_source_picker.py::test_agent_name_fallback_when_identity_missing PASSED +``` + +The picker tests are particularly important — they assert that: +- `same_rank_only=True` correctly rejects rank-0 sources for a rank-2 receiver. +- Multiple stale `READY` rows for the same `worker_rank` collapse to the freshest one (the `(worker_rank, max(updated_at))` dedup that fixes the PrimeRL `NIXL_ERR_NOT_ALLOWED` bug class). +- Sources missing the `mx_v2=1` marker are ignored entirely (forward-compat against future v3+ clients). +- MoE expert filter rejects candidates that don't cover `needed_experts_per_layer`. +- Trainer is always preferred over `inference_replica` for same `(worker_rank, version)`. +- The agent_name fallback path correctly parses `mx_v2|<role>|rank=N|version=K|orig=...` for legacy servers. + +### 8.2 Live cluster gRPC smoke test + +Port-forwarded the running MX server from a workstation: + +```bash +kubectl -n kavin port-forward svc/modelexpress-server 18001:8001 +python3 -c " +from modelexpress import MxClient +import modelexpress.p2p_pb2 as p2p_pb2 +c = MxClient(server_url='localhost:18001') +resp = c.list_sources(status_filter=p2p_pb2.SOURCE_STATUS_READY) +print(f'total READY: {len(resp.instances)}') +for inst in resp.instances[:3]: + meta = c.get_metadata(inst.mx_source_id, inst.worker_id) + print(f' {inst.model_name} rank={inst.worker_rank} ' + f'identity_present={hasattr(meta, \"identity\")} ' + f'agent_name={meta.worker.agent_name!r}') +" +``` + +This is the test that surfaced the proto bugs in §7. Useful to re-run any time the server image changes. 
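
The source-picker rules asserted in §8.1 can be sketched in plain Python. This is a minimal illustration under stated assumptions, not the `pick_best_source` implementation: the `Candidate` dataclass, its field names, and `pick_best` are hypothetical stand-ins, the MoE expert filter is left out, and deduplicating per `(worker_rank, role)` rather than per `worker_rank` alone is an assumption made so that a late-publishing replica cannot displace the trainer.

```python
from dataclasses import dataclass
from typing import Optional

ROLE_TRAINER = "trainer"
ROLE_REPLICA = "inference_replica"


@dataclass(frozen=True)
class Candidate:
    """Hypothetical stand-in for one discovered source row."""
    worker_rank: int
    role: str
    version: int        # training step the source published
    updated_at: int     # server timestamp in ms
    is_v2: bool = True  # True when the row carries the mx_v2=1 marker


def pick_best(cands, my_rank, min_version=0, same_rank_only=True) -> Optional[Candidate]:
    # 1) Ignore rows without the v2 marker or older than the requested version.
    pool = [c for c in cands if c.is_v2 and c.version >= min_version]
    # 2) On multi-subnet fabrics only same-rank sources are eligible.
    if same_rank_only:
        pool = [c for c in pool if c.worker_rank == my_rank]
    # 3) Collapse duplicate rows to the freshest updated_at
    #    (keyed per (worker_rank, role); an assumption of this sketch).
    freshest = {}
    for c in pool:
        key = (c.worker_rank, c.role)
        if key not in freshest or c.updated_at > freshest[key].updated_at:
            freshest[key] = c
    if not freshest:
        return None
    # 4) Newest version wins; at equal version the trainer beats a replica.
    return max(
        freshest.values(),
        key=lambda c: (c.version, c.role == ROLE_TRAINER, c.updated_at),
    )
```

A receiver on rank 2 would call `pick_best(candidates, my_rank=2, min_version=k)` and treat `None` as "no v2 source available yet".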
+ +### 8.3 Live E2E on GB200 — toy scale (correctness) + +Inside the running `prime-rl-nixl-mx-trainer-0` pod (which has 4× B200 GPUs, NIXL, and reachability to MX): + +```bash +# 1) copy our v2 files into the pod (the trainer pod's MX install is older) +SRC=/home/kavink/Work/Github/MX0/modelexpress/modelexpress_client/python/modelexpress +DST=/app/.venv/lib/python3.12/site-packages/modelexpress +for f in shape_descriptors.py nemo_rl_v2.py p2p_pb2.py p2p_pb2_grpc.py __init__.py refit_receiver.py training_publisher.py; do + kubectl -n kavin cp -c trainer "$SRC/$f" prime-rl-nixl-mx-trainer-0:"$DST/$f" +done +kubectl -n kavin cp -c trainer \ + /home/kavink/Work/Github/MX0/modelexpress/modelexpress_client/python/scripts/v2_moe_e2e_demo.py \ + prime-rl-nixl-mx-trainer-0:/tmp/v2_moe_e2e_demo.py + +# 2) run the demo: 4 ranks, 8 experts (chunk=2/rank), HIDDEN=256, 2 cycles +kubectl -n kavin exec prime-rl-nixl-mx-trainer-0 -- bash -c " + cd /tmp && WORLD_SIZE=4 NUM_EXPERTS=8 N_REFIT_CYCLES=2 timeout 90 python3 v2_moe_e2e_demo.py +" +``` + +Output (abridged): + +``` +[trainer R0] published v=0 mx_source_id=393ec6709b204c80 sentinel_target=8 got=8 +[trainer R1] published v=0 mx_source_id=bf2e1ce5d3bebde6 sentinel_target=16 got=16 +[trainer R2] published v=0 mx_source_id=6b057dd75143e1db sentinel_target=24 got=24 +[trainer R3] published v=0 mx_source_id=458954a508b0c650 sentinel_target=32 got=32 + +[inference R0] picked source role=trainer src_rank=0 v=0 updated_at=1778169737062 +[inference R0] received 'model.layers.0.experts.weight' shape=(2, 256, 256) dtype=torch.bfloat16 +[inference R0] received 'model.layers.0.layer_norm.weight' shape=(256,) dtype=torch.bfloat16 +[inference R0] correctness: OK +[inference R1] picked source role=trainer src_rank=1 v=0 ... → OK +[inference R2] picked source role=trainer src_rank=2 v=0 ... → OK +[inference R3] picked source role=trainer src_rank=3 v=0 ... 
→ OK + +# … cycle 1 with version=1 … + +[inference R1] picked source role=trainer src_rank=1 v=1 updated_at=1778169742340 ← freshness dedup picks v=1 over the still-alive v=0 +[inference R1] correctness: OK +=== ALL RANKS OK === +``` + +This validates end-to-end: `MxV2TrainingPublisher.publish` over real gRPC, NIXL register, the **sidecar transport** (`__mx_v2_meta__`) round-trips through the server, `discover_v2_sources` finds the right rank, `pick_best_source` selects the freshest, `receive_from` does a real RDMA WRITE, byte-level sentinels match, and `publish_self_as_source` (tree fan-out) successfully republishes. Per-rank NIC pinning via `MX_RDMA_NIC_PIN=auto` correctly mapped rank N → `mlx5_N:1`. + +### 8.4 Live E2E on GB200 — real DTensors (`torch.distributed.tensor`) + +`scripts/v2_dtensor_e2e_demo.py` mirrors the previous test but uses **real DTensors** instead of fake stand-ins, exercising the exact codepath that `DTensorPolicyWorker.stream_weights_via_mx` runs in production: `init_device_mesh("cuda", (WORLD_SIZE,))`, `distribute_tensor(t, mesh, [Shard(0)])`, then `MxV2TrainingPublisher.add_tensor(tensor=sharded_dt)`. 
+ +```bash +WORLD_SIZE=4 HIDDEN=1024 INTER=2048 N_REFIT_CYCLES=2 python3 v2_dtensor_e2e_demo.py +``` + +Result on `prime-rl-nixl-mx-trainer-0`: + +``` +[trainer R0] DTensor: global=(1024, 2048) local=(256, 2048) sentinel=8 +[trainer R1] DTensor: global=(1024, 2048) local=(256, 2048) sentinel=16 +[trainer R2] DTensor: global=(1024, 2048) local=(256, 2048) sentinel=24 +[trainer R3] DTensor: global=(1024, 2048) local=(256, 2048) sentinel=32 + +[inference R0] registry: global=(1024, 2048) placement=SHARD shard_axis=0 local_range=(0, 256) +[inference R1] registry: global=(1024, 2048) placement=SHARD shard_axis=0 local_range=(256, 512) +[inference R2] registry: global=(1024, 2048) placement=SHARD shard_axis=0 local_range=(512, 768) +[inference R3] registry: global=(1024, 2048) placement=SHARD shard_axis=0 local_range=(768, 1024) +[inference R0..R3] all_elem_match=True for v=0 and v=1 +=== ALL RANKS OK === +``` + +This validates two things v2_moe_e2e_demo.py couldn't: + +1. **`shape_descriptors.describe_tensor` correctly handles real DTensors.** The fix is in commit `9aa4b93` — the previous version assumed `tensor.shape` was the local view (true for plain tensors, **false for DTensors**, where `.shape` is the global un-sharded shape). On a real DTensor, we now read `tensor.to_local().shape[shard_dim]` for the local extent. +2. **The shape registry reaches the receiver via the sidecar.** Previously the registry only travelled through `SourceIdentity.extra_parameters` (which the running server drops). Same commit (`9aa4b93`) embeds `shape_registry` inside the sidecar JSON. Receivers now see `chosen.registry["tensors"]` with correct `global_shape`, `placement_kind=SHARD`, `shard_axis=0`, and rank-correct `local_shard_range`. + +Both are real bugs that the test caught before they would have shipped. They demonstrate why the DTensor E2E test (vs the fake-DTensor toy demo) is essential for closing the validation gap on the NemoRL-wrapper code paths. 
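
The `local_range` values in the registry output above follow from even `Shard(0)` splitting across four ranks. A minimal sketch of that arithmetic, assuming `torch.chunk`-style splitting (ceil-sized chunks, with the last one possibly shorter); `local_shard_range` is a hypothetical helper written for illustration, not a function in `shape_descriptors.py`:

```python
def local_shard_range(global_dim: int, world_size: int, rank: int) -> tuple[int, int]:
    """Half-open [start, end) range that `rank` owns along the sharded axis."""
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} outside world of size {world_size}")
    chunk = -(-global_dim // world_size)  # ceil division
    start = min(rank * chunk, global_dim)
    end = min(start + chunk, global_dim)
    return start, end


# Reproduce the four ranges from the demo: global dim 1024 over 4 ranks.
for r in range(4):
    print(r, local_shard_range(1024, world_size=4, rank=r))
```

Comparing these ranges against `local_shard_range` in the received registry is a cheap way to spot a descriptor that was built from the global shape instead of the local view.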
+ +### 8.5 Live E2E on GB200 — production scale (Qwen3-30B-A3B-shaped) + +Same demo, scaled up: **WORLD_SIZE=4, NUM_EXPERTS=192, HIDDEN=4096** → `(48, 4096, 4096)` bf16 ≈ **1.6 GB / rank**. This is the per-rank shard size for Qwen3-30B-A3B with EP=4. + +``` +[inference R0] 1610.62 MB in 16 ms (102232 MB/s); moe[0,0,0]=8 ln[0]=8 expected=8 OK +[inference R1] 1610.62 MB in 11 ms (142784 MB/s); moe[0,0,0]=16 ln[0]=16 expected=16 OK +[inference R2] 1610.62 MB in 11 ms (144449 MB/s); moe[0,0,0]=24 ln[0]=24 expected=24 OK +[inference R3] 1610.62 MB in 12 ms (131257 MB/s); moe[0,0,0]=32 ln[0]=32 expected=32 OK +=== ALL RANKS OK === +``` + +The 100+ GB/s figures are intra-node `cuda_ipc` (the test pod runs all 4 ranks on one host), not over-the-wire RDMA. So this validates **correctness at production-shape volumes** but not over-the-wire NIC bandwidth — for cross-node we'd see ~7–8 GB/s per NIC, matching what PrimeRL PR #2389 reports. + +### 8.6 Docker overlay image + +```bash +cd ~/Work/Github/RL/RL # or upstream NemoRL clone +docker buildx create --use --name multi-arch --driver docker-container 2>/dev/null || true +docker run --privileged --rm tonistiigi/binfmt --install arm64 # one-time qemu setup +docker pull --platform linux/arm64 nvcr.io/nvidia/nemo-rl:v0.6.0 # ~5 GB + +docker buildx build \ + --platform linux/arm64 \ + --build-context modelexpress=$HOME/Work/Github/MX0/modelexpress \ + --build-context nemo-rl-source=. \ + -f docker/v2_overlay/Dockerfile \ + --tag nvcr.io/nvidian/dynamo-dev/nemo-rl:kavink-v2 \ + --load . 
+ +# In-image smoke (qemu-aarch64): +docker run --rm --platform linux/arm64 nvcr.io/nvidian/dynamo-dev/nemo-rl:kavink-v2 \ + /opt/nemo_rl_venv/bin/python -c " +from modelexpress import MxV2TrainingPublisher, MxV2RefitReceiver, TrainerWorldLayout +from nemo_rl.distributed.mx_helpers import MxConfig, build_v2_publisher +from nemo_rl.models.policy.interfaces import ColocatablePolicyInterface +from nemo_rl.models.generation.interfaces import GenerationInterface +assert hasattr(ColocatablePolicyInterface, 'stream_weights_via_mx') +assert hasattr(GenerationInterface, 'update_weights_via_mx') +print('nemo_rl × mx v2 imports OK') +" +``` + +Image: 34.6 GB on disk (`v0.6.0` base + ~6 MB our overlay). Build time on x86 host with qemu-aarch64: ~10 min (cached layers reuse afterward). + +The Dockerfile is intentionally minimal — it overlays the entire `nemo_rl/` package (not just the modified files) because v0.6.0 is older than `main` and partial overlays cause missing-symbol import errors (e.g. `resolve_generation_worker_cls` isn't in v0.6.0's `vllm_generation/utils.py`). The MX wheel is `pip install --no-deps` (the venv already has compatible `grpcio` / `protobuf`). + +### 8.7 What was NOT validated end-to-end + +These compile + import + pass linting but no integration test drove the codepath yet: + +- `DTensorPolicyWorker.stream_weights_via_mx` — the **inner mechanics** (DTensor → `to_local()` → `add_tensor` → publish, plus shape-registry placement metadata) are now exercised by §8.4. What's still untouched: the NemoRL-specific outer glue — `model.state_dict()` walk over an actual NemoRL-wrapped HF model, the MoE detection heuristic on real expert layer naming, `cpu_offload` lifecycle. None of these change the v2 protocol; they're integration-level. +- `VllmInternalWorkerExtension.update_weights_via_mx` — the inference-side bridge that registers `model_runner.model.named_parameters()`, calls `_load_weights`, applies the GptOss transpose fix, `_maybe_process_fp8_kv_cache`. 
Same — receiver was driven, but not the vLLM glue around it. +- `refit_policy_generation`'s `mx` branch (the Ray fan-out: `policy.stream_weights_via_mx` ‖ `policy_generation.update_weights_via_mx`, then `ray.get`). +- `lm_policy.Policy.stream_weights_via_mx` — Ray actor fan-out wrapper. +- The MX server-side Rust changes (compile-checked via `ReadLints`; not deployed since the running cluster image predates them). +- **Multi-node** RDMA — everything was intra-node so far. +- Tree fan-out **under load** — `publish_self_as_source` ran but no second receiver actually preferred a replica over the trainer. +- Heartbeat lifecycle under churn (kill workers, watch reaping). +- Failure-recovery rerouting if a tree-fan-out source dies mid-write. +- Megatron worker, SGLang generation, dirty-experts bitmap, cross-DC, async / in-flight refit. + +--- + +## 9. Server-side patch path (when you want to graduate the prototype) + +Three sequenced steps, all on `kavink/nemo_rl_moe`: + +### Step 1 — rebuild the server image with the SourceIdentity round-trip + +```bash +cd modelexpress +cargo build --release -p modelexpress_server # locally or in CI +docker buildx build \ + --platform linux/arm64 \ + -f modelexpress_server/Dockerfile \ + --tag nvcr.io/nvidian/dynamo-dev/modelexpress-server:kavink-v2 \ + --push . +``` + +Then redeploy the `modelexpress-server` Deployment in `kavin` namespace pointing at the new tag. After this, `MxV2RefitReceiver.discover_v2_sources` automatically uses transport Path 1 (`SourceIdentity.extra_parameters`); the sidecar TensorDescriptor stays in the wire as a no-op. + +### Step 2 — push the NemoRL overlay image + +```bash +docker login nvcr.io -u '$oauthtoken' -p $NGC_API_KEY +docker push nvcr.io/nvidian/dynamo-dev/nemo-rl:kavink-v2 +``` + +### Step 3 — deploy a NemoRL training job pointing at MX + +Mirror the prime-rl-nixl-mx K8s manifest layout (Ray head + worker pods + StatefulSet for trainer) but with NemoRL's actor model. 
Driving config: + +```yaml +# config.yaml (NemoRL job) +cluster: + weight_sync: + method: "mx" + enabled: true + mx_server_url: "modelexpress-server.kavin.svc.cluster.local:8001" + timeout_seconds: 300.0 + same_rank_only: true + tree_scale_out: true + moe_expert_filter: true + nic_pin: "auto" + retain_latest_k: 1 + +# trainer / generation as usual: +policy: + model_name: "Qwen/Qwen3-30B-A3B-Instruct-2507" + parallelism: + fsdp: 4 + tp: 1 + pp: 1 +generation: + vllm: + tensor_parallel_size: 1 + data_parallel_size: 4 +``` + +What to watch for in trainer logs once running: + +``` +[modelexpress.nemo_rl_v2] MxV2TrainingPublisher initialized: rank=0 layout=fsdp:4,tp:1,pp:1,ep:1 +[modelexpress.nemo_rl_v2] MxV2 publish: rank=0 version=K tensors=N mx_source_id=... +[modelexpress.heartbeat] [Worker 0] Heartbeat started (interval=30s) +[modelexpress.heartbeat] [Worker 0] Status -> READY +``` + +What to watch for in inference logs: + +``` +[mx] rank=0 chosen source role=trainer src_rank=0 version=K +... NIXL transfer complete: , tensors, s ... +``` + +If you see `[mx] no v2 source available for version>=K on rank N`, check that: +1. Both pods agree on `model_name`. +2. `same_rank_only=True` and trainer's `worker_rank` matches inference's. +3. Heartbeat is alive on the trainer (`HeartbeatThread` started). + +### Step 4 (optional) — push the branches for code review + +```bash +git -C ~/Work/Github/MX0/modelexpress push origin kavink/nemo_rl_moe +git -C ~/Work/Github/RL/RL push origin kavink/mx_integration +``` + +Then open PRs against the respective repos. The MX-side PR description can pull §1–§7 of this doc; the NemoRL-side PR can stay shorter (link back to this doc for the design rationale). + +--- + +## 10. 
Roadmap (designed but not implemented) + +| Item | Why | Where it goes | +|---|---|---| +| Async / in-flight refit (Composer 2 / PipelineRL style) | Don't block training step on refit completion | `nemo_rl/algorithms/grpo.py::refit_policy_generation_async` + a `_AsyncMxRefitDaemon` on the inference side that hot-swaps weights between rollouts | +| Megatron worker `stream_weights_via_mx` | Coverage of the second NemoRL trainer backend | `nemo_rl/models/policy/workers/megatron_policy_worker.py` | +| SGLang generation `update_weights_via_mx` | Coverage of the second NemoRL inference backend | `nemo_rl/models/generation/sglang/` | +| Dirty-experts bitmap | Composer-2-style "only refit changed experts" | MX-side `set_dirty_experts` / `get_dirty_experts` RPCs + receiver-side `needed = my_owned ∩ dirty` filter | +| Cross-DC seeding (TCP fallback) | Multi-DC rollouts | MX-side TopologyScheduler with datacenter-aware source selection | +| Failure-recovery rerouting | A tree-fan-out source dies mid-write | Receiver detects RDMA fail → `report_source_failure` → server marks stale → retry against next-best | +| Mutability contract drain | TensorHub §3.2 | Server-side `set_status(STALE)` blocks until in-flight reads complete | + +--- + +## 11. References + +- **PrimeRL PR #2389** (the GB200 multi-subnet RDMA topology lessons that shaped v2 defaults): [PrimeIntellect-ai/prime-rl#2389](https://github.com/PrimeIntellect-ai/prime-rl/pull/2389). +- **TensorHub paper** (ROS, mutability contract, retention protocol, pipeline replication): [arXiv 2604.09107v1](https://arxiv.org/pdf/2604.09107v1). +- **Composer 2 technical report** (router replay + per-expert delta compression): [Cursor Composer 2](https://cursor.com/resources/Composer2.pdf). +- **Sister framework integrations**: `docs/RL/PRIMERL_MX_OVERVIEW.md`, `docs/RL/VERL_MX_OVERVIEW.md`. +- **MX architecture**: `docs/ARCHITECTURE.md`, `docs/metadata.md`, `docs/DEPLOYMENT.md`. 
diff --git a/docs/RL/NIXL_COMPRESSION_STUDY.md b/docs/RL/NIXL_COMPRESSION_STUDY.md new file mode 100644 index 00000000..e5df9cd4 --- /dev/null +++ b/docs/RL/NIXL_COMPRESSION_STUDY.md @@ -0,0 +1,303 @@ +# NIXL nvCOMP Compression Study — Reproducing with ModelExpress RL Workflows + +**Last Updated**: April 29, 2026 +**Audience**: NIXL compression team +**Purpose**: Guide the NIXL team to capture and study real RL weight-transfer payloads using our validated PRIME-RL and verl workflows with ModelExpress (MX). + +--- + +## Background + +The NIXL team is evaluating nvCOMP GPU compression on the tensors that flow through NIXL during RL post-training. There are two transfer types: + +1. **RL refit** (training → inference): full model weights, every RL step. +2. **KV cache** (prefill → decode): per-request KV tensors in disaggregated inference. + +We have **two validated end-to-end RL workflows** that produce these payloads over NIXL on GB200: + +| Workflow | Framework | Status | PR | What it exercises | +|----------|-----------|--------|-----|-------------------| +| **PRIME-RL overlay** | PRIME-RL + vLLM | Scenarios A/B/C green on GB200 (20/20 steps each) | [PrimeIntellect-ai/prime-rl#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343) | NIXL RDMA weight push via PI's `NIXLWeightBroadcast` + `TransportPlan`, MX-mediated discovery | +| **verl MxCheckpointEngine** | verl + vLLM | 10 steps green on GB200 | [ai-dynamo/modelexpress#252](https://github.com/ai-dynamo/modelexpress/pull/252) | NIXL RDMA weight transfer via `MxCheckpointEngine` (`CheckpointEngine` plugin) | + +Both produce the **exact same kind of data** the NIXL team requested: raw BF16 weight tensors flowing GPU-to-GPU over NIXL, plus pre/post RL-step weight deltas for delta-compression analysis. + +--- + +## Component View + +Where the data being studied actually lives + what writes/reads it. Green is what the compression team would compress; purple is the existing RL stack producing it. 
 + +```mermaid +flowchart TB + subgraph trainer_node["Trainer node — GB200"] + direction TB + T_FSDP["FSDP2 trainer<br/>optimizer.step()"] + T_PUB["MxTrainingPublisher<br/>+ NIXLWeightBroadcast"] + T_NIXL(["NIXL agent (UCX rc_mlx5)"]) + T_FSDP --> T_PUB --> T_NIXL + end + + subgraph mx_meta["Metadata plane — MX Server"] + MX["MX Server<br/>(gRPC)"] + REDIS[("Redis")] + MX --> REDIS + end + + subgraph inference_node["Inference node — GB200"] + direction TB + I_NIXL(["NIXL agent (UCX rc_mlx5)"]) + I_RECV["MxRefitReceiver<br/>+ NIXLWeightUpdateWorker"] + VLLM["vLLM engine<br/>(live params)"] + I_NIXL --> I_RECV --> VLLM + end + + subgraph prefill_node["Prefill worker — GB200 (disagg inference)"] + direction TB + P_FWD["vLLM prefill pass"] + P_KV["KV cache buffer<br/>(per-layer key/value)"] + P_FWD --> P_KV + end + + subgraph decode_node["Decode worker — GB200"] + direction TB + D_KV["KV cache import"] + D_GEN["Token generation"] + D_KV --> D_GEN + end + + T_PUB -. "publish metadata<br/>(SourceIdentity, agent blob)" .-> MX + MX -. "discover" .-> I_RECV + + T_NIXL ==> |"① RL REFIT<br/>weights (BF16)<br/>~3 GB / step (1.5B model)<br/>~140 GB / step (70B)"| I_NIXL + P_KV ==> |"② KV CACHE<br/>tensors (BF16)<br/>~14 MB at seq=512<br/>~3.5 GB at seq=131K"| D_KV + + style T_FSDP fill:#533483,stroke:#e94560,color:#fff + style T_PUB fill:#533483,stroke:#e94560,color:#fff + style T_NIXL fill:#1b5e20,stroke:#4caf50,color:#fff + style I_NIXL fill:#1b5e20,stroke:#4caf50,color:#fff + style I_RECV fill:#533483,stroke:#e94560,color:#fff + style VLLM fill:#533483,stroke:#e94560,color:#fff + style P_FWD fill:#533483,stroke:#e94560,color:#fff + style P_KV fill:#1b5e20,stroke:#4caf50,color:#fff + style D_KV fill:#1b5e20,stroke:#4caf50,color:#fff + style D_GEN fill:#533483,stroke:#e94560,color:#fff + style MX fill:#533483,stroke:#e94560,color:#fff + style REDIS fill:#162447,stroke:#533483,color:#e0e0e0 +``` + +**Compression target = the green edges.** ① is the RL-refit path between trainer and inference NIXL agents; ② is the KV cache transfer between prefill and decode. Everything purple — the trainer, the MX Server, vLLM, the receiver — is RL-stack infrastructure that wouldn't change if nvCOMP is added at the NIXL layer (compression would slot in transparently between `register` and `RDMA WRITE` on either edge). + +The capture scripts in [`scripts/`](./scripts/) snapshot the bytes that cross those green edges, plus pre/post weight tensors for delta-compression analysis. + +--- + +## Option 1: Request the pre-captured data package (fastest) + +We have a ready-made data package captured from a live PRIME-RL deployment on GB200. **It's not in this repo** (binary tensors at GB scale aren't appropriate to commit) — request access from `kavink@nvidia.com` and we'll share via the appropriate channel (NV S3 bucket, internal share, or direct upload). 
+ +Package contents: + +```text +RL_Qwen25/ +├── model.safetensors # 2.9 GB — all 338 weight tensors (BF16) +├── weights_pre_rl.safetensors # 3.4 GB — weights before optimizer.step() +├── weights_post_rl.safetensors # 3.4 GB — weights after 1 AdamW step (lr=5e-6) +├── weight_deltas.safetensors # 3.4 GB — elementwise diff (post - pre), BF16 +├── kv_cache/ # 14 MB — 56 KV tensors from a 501-token prefill +│ ├── layer_0_key.bin # shape [1, 2, 501, 128], BF16 +│ ├── layer_0_value.bin +│ ├── ... +│ └── manifest.json # per-tensor metadata +├── manifest.json # 66 KB — per-weight-tensor metadata +└── README.md # full layout + compression properties +``` + +**Model**: Qwen2.5-1.5B BF16, 28 layers, 1.54B parameters. ~14 GB total package size. + +**How to read**: + +```python +from safetensors import safe_open +import torch + +# Weights (the exact tensors NIXL transfers during RL refit) +with safe_open("model.safetensors", framework="pt") as f: + for key in f.keys(): + tensor = f.get_tensor(key) # torch.bfloat16 + raw_bytes = tensor.contiguous().untyped_storage() # raw bytes as on the wire + print(f"{key}: {tensor.shape}, {len(raw_bytes)} bytes") + +# Weight delta (for delta-compression analysis — compute in FP32 for precision) +pre = safe_open("weights_pre_rl.safetensors", framework="pt") +post = safe_open("weights_post_rl.safetensors", framework="pt") +for key in pre.keys(): + delta = post.get_tensor(key).float() - pre.get_tensor(key).float() + print(f"{key}: max_abs_delta={delta.abs().max():.2e}") + +# KV cache (the exact tensors transferred prefill → decode via NIXL) +raw = open("kv_cache/layer_0_key.bin", "rb").read() +kv = torch.frombuffer(bytearray(raw), dtype=torch.bfloat16).reshape(1, 2, 501, 128) +``` + +**Key finding on deltas**: At BF16 precision, single-step RL deltas are mostly zero (AdamW updates at lr=5e-6 are below BF16's representable precision). For meaningful delta analysis, compute diffs in FP32. 
This suggests delta-compression should operate in FP32 and quantize back after. + +--- + +## Option 2: Reproduce end-to-end on GB200 (PRIME-RL overlay) + +Run our validated PRIME-RL overlay workflow and capture weights mid-flight using the published [`scripts/`](./scripts/) directory. + +### Prerequisites + +- A GB200 cluster (ARM64) with at least 2 nodes, container runtime, and an RDMA-capable interconnect (InfiniBand or RoCE) between nodes. Cluster orchestration is Kubernetes-based; manifests assume `kubectl` access and a working namespace. +- A namespace where you'll deploy the overlay, with the following bound: + - MX Server reachable at `modelexpress-server..svc.cluster.local:8001` (Helm chart in this repo, or use an existing deployment) + - Redis backing the MX Server + - A shared model-cache PVC for HuggingFace downloads + - An image pull secret for the registry hosting the overlay image (we publish to `nvcr.io/nvidian/dynamo-dev/`; you can also build locally) + +The included K8s manifests under `prime-rl/k8s/prime-rl-mx-on-nixl/` may need light edits (namespace, node selectors, image pull secret name, RDMA network annotations) for your cluster — they're examples, not portable across all GB200 deployments. The capture flow itself is cluster-agnostic. + +### Step 1: Deploy the PRIME-RL overlay + +```bash +git clone git@github.com:KavinKrishnan/prime-rl.git +cd prime-rl +git checkout kavink/mx-on-nixl + +# Use the pre-built ARM64 image, or build locally +# Pre-built: nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.2 +docker buildx build --platform linux/arm64 \ + -f docker/Dockerfile.mx-on-nixl \ + -t /prime-rl-mx-on-nixl:v0.2 \ + --push . + +# Edit k8s/prime-rl-mx-on-nixl/*.yaml for your namespace, node selectors, +# RDMA network annotations, and image registry. 
Then: +cd k8s/prime-rl-mx-on-nixl +./run.sh deploy A # scenario A = PI's NIXL transport, no MX env vars +./run.sh status # wait until all 3 pods are Running +``` + +### Step 2: Verify the RL loop is running + +```bash +NS= +kubectl -n $NS logs prime-rl-mx-on-nixl-trainer-0 --tail=20 | grep "SUCCESS.*Step" +kubectl -n $NS logs prime-rl-mx-on-nixl-inference-0 | grep "update_weights.*200" +``` + +### Step 3: Capture using the published script + +We ship `capture_on_pod.py` in [`scripts/`](./scripts/) — same script that produced our pre-captured Qwen2.5-1.5B package. It captures pre/post RL weights, simulates one AdamW step, computes deltas, and dumps a KV cache prefill, all in one pass. + +```bash +NS= + +# Copy the script into the trainer pod +kubectl cp docs/RL/scripts/capture_on_pod.py \ + $NS/prime-rl-mx-on-nixl-trainer-0:/tmp/capture.py + +# Run it inside the pod (overlay image's interpreter is /app/.venv/bin/python) +kubectl exec $NS/prime-rl-mx-on-nixl-trainer-0 -- /app/.venv/bin/python /tmp/capture.py \ + --model Qwen/Qwen2.5-1.5B \ + --out /tmp/nixl_capture \ + --kv-seq-len 512 \ + --lr 5e-6 + +# Copy the results back +kubectl cp $NS/prime-rl-mx-on-nixl-trainer-0:/tmp/nixl_capture ./RL_capture +``` + +Output `RL_capture/` contains four sub-directories (`weights_pre_rl/`, `weights_post_rl/`, `weight_deltas/`, `kv_cache/`) each with raw `.bin` files plus a `manifest.json`. See [`scripts/README.md`](./scripts/README.md) for the full layout + flag reference. 
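
Step 3 simulates one AdamW step at `--lr 5e-6` and stores the result in BF16. The following stdlib-only sketch illustrates why most such updates vanish at that precision. It rests on assumptions: `to_bf16` is a round-to-nearest-even emulation written for this example (it is not code from `capture_on_pod.py`), 0.02 is an assumed typical weight magnitude, and the update size is taken to be roughly equal to the learning rate.

```python
import struct


def to_bf16(x: float) -> float:
    """Round a finite float to the nearest bfloat16 value (ties to even)."""
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # bfloat16 keeps the top 16 bits of the float32 pattern;
    # the bias implements round-to-nearest-even on the dropped half.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    bits = ((bits + rounding_bias) >> 16) << 16
    return struct.unpack("<f", struct.pack("<I", bits & 0xFFFFFFFF))[0]


w = to_bf16(0.02)   # assumed typical weight magnitude, stored in BF16
step = 5e-6         # assumed size of one AdamW update at lr=5e-6
ulp = 2.0 ** -13    # BF16 spacing for values in [2^-6, 2^-5)

print(to_bf16(w - step) == w)  # True: the update is below half an ulp and rounds away
print(to_bf16(w - ulp) == w)   # False: a full-ulp change survives the rounding
```

The same reasoning suggests that a delta only becomes visible once accumulated updates on a given weight approach half of its BF16 spacing, which depends on the weight's magnitude.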
+ +### Step 4 (optional): Capture without a running RL deployment + +If reproducing the overlay is more cluster work than the data is worth, [`scripts/capture_weights_and_kv.py`](./scripts/capture_weights_and_kv.py) is the **standalone** variant — works on any host (CPU or single GPU), no Kubernetes / RL framework required: + +```bash +pip install torch transformers safetensors + +python docs/RL/scripts/capture_weights_and_kv.py \ + --model Qwen/Qwen2.5-1.5B \ + --output-dir ./nixl_data \ + --dtype bfloat16 \ + --device cpu +``` + +Doesn't simulate an RL step (no pre/post/delta), but produces the same weight + KV cache layout the NIXL team can compress against. + +### Step 5: Tear down + +```bash +./run.sh clean +``` + +--- + +## Option 3: Reproduce with verl MxCheckpointEngine + +The verl integration uses the same MX client but through verl's `CheckpointEngine` plugin. This path captures the weights as they flow through `MxCheckpointEngine.send_weights()` / `receive_weights()`. + +Deployment docs: `docs/RL/VERL_MX_OVERVIEW.md` §6 in the modelexpress repo. + +The capture approach is the same as Option 2 (exec into the trainer pod, save state dict pre/post step) since the weight tensors are identical — both frameworks produce `model.named_parameters()` in BF16. The difference is the transport path (verl's bucket+ZMQ metadata vs prime-rl's TransportPlan+slot system), which doesn't affect the tensor content. 
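
Because both paths hand you BF16 tensors, a pre/post diff taken directly on them is subject to BF16 rounding. The sketch below illustrates this with one scalar, using a hand-written round-to-nearest-even conversion in pure Python. The weight `0.05` and update `1e-7` are illustrative values chosen to sit in the ranges this document quotes, not captured data, and the helper names are hypothetical.

```python
import struct


def f32_round(x: float) -> float:
    """Round a Python float to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16_bits(x: float) -> int:
    """Round a finite float to BF16 bits (round-to-nearest-even).

    NaN is not handled specially in this sketch.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    # Add half of the discarded range, plus the tie-breaking bit, then truncate.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


def from_bf16_bits(b: int) -> float:
    """Expand BF16 bits back to a float by zero-filling the low 16 bits."""
    return struct.unpack("<f", struct.pack("<I", b << 16))[0]


weight = 0.05   # illustrative weight magnitude
update = 1e-7   # illustrative per-element update

# Delta as seen when both snapshots are stored in BF16.
delta_bf16 = from_bf16_bits(to_bf16_bits(weight + update)) - from_bf16_bits(to_bf16_bits(weight))

# Delta as seen when both snapshots are kept in FP32.
delta_f32 = f32_round(weight + update) - f32_round(weight)

print(f"BF16 delta: {delta_bf16!r}")
print(f"FP32 delta: {delta_f32!r}")
```

In this case the BF16 delta is exactly zero while the FP32 delta is not, which is why any delta analysis on either framework's capture should start from FP32 values where they are available.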
+ +--- + +## What to capture for the compression study + +| Artifact | File | Size (Qwen3-0.6B) | What it represents | +|----------|------|-------|---------------------| +| **Current weights** | `weights_current.safetensors` | ~1.2 GB | Exact tensors registered with NIXL and RDMA-written to inference GPU every RL step | +| **Post-step weights** | `weights_post_step.safetensors` | ~1.2 GB | After one AdamW step (lr=5e-6) | +| **Weight deltas** | `weight_deltas.safetensors` | ~1.2 GB | `post - pre` in BF16 (mostly zero — compute in FP32 for real deltas) | +| **KV cache** | `kv_cache/*.bin` | ~14 MB | Prefill output transferred to decode workers via NIXL | +| **Manifest** | `manifest.json` | ~30 KB | Per-tensor: name, shape, dtype, size_bytes | + +### Larger models for more representative data + +The steps above use Qwen3-0.6B (our scenario A model). For larger models closer to production: + +| Model | Params | Weight payload | Notes | +|-------|--------|----------------|-------| +| Qwen3-0.6B (above) | 0.6B | ~1.2 GB | Validated in PR #2343 scenarios A/B/C | +| Qwen2.5-1.5B | 1.5B | ~3 GB | Pre-captured package available on request (see Option 1) | +| Qwen2.5-7B | 7.6B | ~15 GB | T1 model in our overlay plan | +| Qwen3-MoE (PI offered spec) | MoE | varies | Would exercise `ExpertSlot` + per-expert tensors — most representative for MoE compression | + +For models requiring multiple GPUs, the weights are FSDP-sharded — each rank's shard is `total / num_ranks` in size. The bytes on the wire per-rank are the shard size, not the full model. 
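
As a rough sizing aid, the per-rank arithmetic can be sketched as follows. This assumes BF16 (2 bytes per element) and an even split across ranks; real sharding may pad tensors, so treat the result as an estimate. The function names and the rank count of 8 are hypothetical.

```python
def payload_bytes(num_params: int, bytes_per_elem: int = 2) -> int:
    """Total weight payload in bytes (BF16 by default)."""
    return num_params * bytes_per_elem


def shard_bytes(num_params: int, num_ranks: int, bytes_per_elem: int = 2) -> int:
    """Approximate bytes on the wire per rank for an evenly sharded model."""
    total = payload_bytes(num_params, bytes_per_elem)
    # Ceiling division: the last shard is rounded up rather than truncated.
    return -(-total // num_ranks)


# Qwen2.5-7B from the table above: 7.6B params is roughly 15 GB in BF16.
full = payload_bytes(7_600_000_000)
per_rank = shard_bytes(7_600_000_000, num_ranks=8)  # 8 ranks is an example value
print(f"full model: {full / 1e9:.1f} GB, per rank: {per_rank / 1e9:.2f} GB")
```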
+ +--- + +## Compression-relevant properties + +| Property | Weights | KV Cache | Delta (FP32) | +|----------|---------|----------|--------------| +| **Dtype on wire** | BF16 (2 B/elem) | BF16 (2 B/elem) | BF16 stored, but FP32 is the meaningful analysis dtype | +| **Value distribution** | Normal, centered ~0, std 0.01–0.1 | Wider, context-dependent | Very small magnitude (~1e-8 to 1e-6 per element) | +| **Sparsity** | Dense (no zeros) | Dense | ~100% zero at BF16 precision; structured-sparse at FP32 | +| **Best compression angle** | Entropy coding on mantissa bits | Temporal locality across layers | FP32 delta + entropy coding — high compressibility expected | +| **Transfer frequency** | Every RL step (~5–60 s) | Every request | Once for analysis | +| **Bucket size on wire** | 596 MB (measured in scenario A/B/C) | per-request, scales with seq_len | N/A | + +### NIXL integration point for nvCOMP + +If nvCOMP compression is added at the NIXL layer, the integration is transparent to both MX and the RL frameworks: + +```text +Current: + Training GPU → NIXL register → RDMA WRITE (raw bytes) → Inference GPU + +With NIXL-layer nvCOMP: + Training GPU → NIXL register → nvCOMP compress (GPU) → RDMA WRITE (compressed) → nvCOMP decompress (GPU) → Inference GPU +``` + +No changes to `MxTrainingPublisher`, `MxRefitReceiver`, `NIXLWeightBroadcast`, `TransportPlan`, or the MX Server protocol. Compression is internal to NIXL's transfer path. Our bucket-streaming pattern is preserved — compression happens per-bucket. + +--- + +## Questions? + +Reach out to Kavin Krishnan (`kavink@nvidia.com`) for access to the pre-captured data or help reproducing on a cluster. The PRIME-RL overlay branch (`KavinKrishnan/prime-rl:kavink/mx-on-nixl`) and the modelexpress RL branch (`ai-dynamo/modelexpress:kavink/RL`) are the entry points. 
diff --git a/docs/RL/PRIMERL_MX_NATIVE_DESIGN.md b/docs/RL/PRIMERL_MX_NATIVE_DESIGN.md new file mode 100644 index 00000000..540a272e --- /dev/null +++ b/docs/RL/PRIMERL_MX_NATIVE_DESIGN.md @@ -0,0 +1,276 @@ +# PRIME-RL × ModelExpress — Native API Design (Path B) + +**Status**: Design proposal (no code yet) +**Last Updated**: April 2026 +**Companion to**: `PRIMERL_MX_OVERVIEW.md` (the overlay-on-PI design, "Path A") + +This document describes a **native MX-shaped weight broadcast backend** for PRIME-RL that uses [PI's NIXL transport](https://github.com/PrimeIntellect-ai/prime-rl/pull/2326) as the bytes-on-wire data plane, but exposes **ModelExpress's traditional API surface** to PRIME-RL — model-agnostic, server-mediated, scratch-buffer-default, cross-framework. + +It's the alternative to "Path A" (strict overlay on PI's API). Path A is in flight as draft PR [PrimeIntellect-ai/prime-rl#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343); Path B is staged here as the design we'd advocate for if Path A hits friction or if the team wants the broader strategic value. + +--- + +## 1. Why This Doc Exists + +Path A's overlay strategy preserves PI's NIXL API surface end-to-end and replaces only the SPG rendezvous with MX Server discovery. It's small, cooperative, easy to merge — and it inherits every PI design constraint: + +- **Per-model `conversion_specs()` requirement.** PI's `TransportPlan` only supports models that PI has authored a spec table for: today, just `glm_moe_dsa`. Plain Qwen3, Llama, Mixtral, anything else needs a spec table written. We just discovered this when scenario A's trainer crashed with `'FSDPQwen3ForCausalLM' object has no attribute 'conversion_specs'` and had to monkey-patch HF Qwen3 to unblock. +- **Direct refit only — KL drift class of bugs is in scope.** PI's PR is currently blocked by 27+ iterations of KL drift investigation. 
Their byte-exact transport is correct; the drift comes from concurrent UCX writes into live vLLM `param.data`. Inheriting their target-buffer model means inheriting that bug surface. +- **Static, startup-time tensor registration.** `Slot{Sharded,Gathered,Expert}` assume fixed tensor shapes registered with NIXL once at init. Elastic workloads (LoRA-RL with dynamic adapter add/remove, frozen-policy-adapts variants, growing context-len rollouts) don't fit this model cleanly. +- **prime-rl-only.** PI's `NIXLWeightBroadcast` lives in `prime_rl/trainer/rl/broadcast/nixl.py`; cross-framework portability would require us to copy the design into verl + future NeMo-RL. + +Path B is the "what if we redesigned the prime-rl-side weight broadcast around MX's shape, but kept PI's UCX/NIXL setup as-is for the bytes" answer. The data plane is identical (same RDMA bytes on wire, same per-NIC bandwidth, same `rc_mlx5` transport, same per-rank NIC pin). Only the control plane and tensor ABI change. + +--- + +## 2. 
What MX-Traditional Looks Like (Recap) + +ModelExpress has a consistent API across the integrations we've shipped (verl `MxCheckpointEngine`, our existing PRIME-RL `ModelExpressWeightBroadcast` on `kavink/mx-weight-broadcast`): + +```python +# Trainer side +publisher = MxTrainingPublisher( + agent_name="trainer-rank-0", + device_id=local_rank, + server_url="modelexpress-server.kavin.svc.cluster.local:8001", +) +publisher.initialize() # NIXL agent up, gRPC connected +publisher.publish_weights(state_dict, step=N) # per training step +# ...later, before optimizer.step() reuses buffers: +publisher.unpublish(version=N) # mutability contract + + +# Inference side +receiver = MxRefitReceiver( + model_name="...", + worker_rank=k, + device_id=local_rank, +) +receiver.initialize(model_tensors=scratch_or_live_dict) # NIXL register receive buffers +source = receiver.poll_for_source(min_version=N) # discover trainer +receiver.receive_weights(source) # RDMA pull (or accept WRITE) +# scratch path: vllm_model.load_weights(scratch_iter) +# direct path: receive lands directly in vllm_model.param.data +``` + +Salient differences from PI's design: + +| Concern | PI (Path A overlay inherits) | MX-traditional (Path B exposes) | +|---------|------------------------------|--------------------------------| +| Discovery | SPG static rendezvous | gRPC `MxClient` (publish, list_sources, get_metadata) | +| Source identity | Implicit rank pairing | Content-addressed `mx_source_id = sha256(SourceIdentity)` | +| Per-model contract | `model.conversion_specs(layer_idx)` | None — model-agnostic; publishes live `state_dict` tuples | +| Tensor ABI | Fixed `Slot{...}` registered at startup | Per-step `(name, shape, dtype, gpu_addr)` published; receiver re-registers if shape drifts | +| Receive target | Direct WRITE into live `param.data` | **Scratch buffer + `model.load_weights()`** by default; opt-in direct refit | +| Quantization | First-class `ConversionSpec`/`QuantizationSpec` | Trainer 
pre-quantizes if needed; MX is dtype-agnostic | +| Lifecycle | rendezvous → register × N → write × N per startup | publish/poll/receive/unpublish per step; no static startup registration | +| Versioning | Implicit step counter | First-class `extra_parameters.training_step` | +| Mutability contract | Implicit (root cause of KL drift?) | Explicit `unpublish()` with drain | +| Cross-framework | prime-rl only | Same client runs in verl, prime-rl, future NeMo-RL | +| vLLM integration | Worker extension only | Worker extension OR `WeightTransferEngine` plugin (Step 11) | + +--- + +## 3. Path B Architecture + +Same NIXL data plane as PI; different prime-rl-side abstractions on top of it. + +```text +┌───────────────────────────────────────────────────────────────────┐ +│ prime-rl trainer process │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ MxWeightBroadcast (new, in prime_rl/trainer/rl/broadcast/) │ │ +│ │ ─ implements PI's WeightBroadcast ABC │ │ +│ │ ─ delegates the data path to MxTrainingPublisher │ │ +│ │ ─ delegates the control path to MxClient gRPC │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +│ │ │ │ +│ ▼ data ▼ metadata │ +│ MxTrainingPublisher MxClient(server_url=...) 
│ +│ (modelexpress) (modelexpress) │ +│ │ │ │ +│ ▼ ▼ │ +│ NixlAgentWrapper ┌─────────────────────────────┐ │ +│ (PI's existing class) │ MX Server (gRPC + Redis) │ │ +│ │ │ ─ source registry │ │ +│ ▼ post_write │ ─ poll_for_source │ │ +│ UCX rc_mlx5 RDMA │ ─ pipeline replication │ │ +│ (same wire as PI) │ ─ retention + versioning │ │ +│ │ └─────────────────────────────┘ │ +│ ▼ │ +│ ConnectX-7 NIC ─── RoCE ───► rollout NIC │ +│ │ +└───────────────────────────────────────────────────────────────────┘ + +┌───────────────────────────────────────────────────────────────────┐ +│ prime-rl inference (vLLM) worker │ +│ ┌─────────────────────────────────────────────────────────────┐ │ +│ │ MxWeightUpdateWorker (new) │ │ +│ │ ─ vLLM worker extension │ │ +│ │ ─ MxRefitReceiver delegate │ │ +│ │ ─ default: scratch buffer + model.load_weights() │ │ +│ │ ─ opt-in: direct refit into live param.data │ │ +│ └─────────────────────────────────────────────────────────────┘ │ +│ │ │ +│ ▼ │ +│ MxRefitReceiver (modelexpress) │ +│ │ │ +│ ▼ │ +│ NixlAgentWrapper (PI's existing class) │ +│ │ │ +│ ▼ │ +│ accepts WRITE from trainer (same as PI) │ +└───────────────────────────────────────────────────────────────────┘ +``` + +### What we adopt unchanged from PI's PR + +These are real engineering wins; we'd be foolish to redo them: + +- `NixlAgentWrapper` (UCX agent setup, register_tensor, prep_local/prep_remote, post_write, wait, drain) +- `pin_ucx_rail` (per-rank NIC pinning that's the difference between 4.8 GB/s and 7.5 GB/s) +- `classic_cuda_pool` (allocator workaround for `expandable_segments` + `ibv_reg_mr` "local protection" bug) +- The runtime image stack (UCX 1.19, NIXL 0.10.1, ARM64 quirks) +- Pre-write SPG barrier / quiescence pattern (the iter15 fix; we'd implement equivalent via MX Server fence RPC) +- HSDP primary-replica gate (only `dp_replicate == 0` runs the protocol) + +These come into MX as either direct re-export or via a `prime_rl.utils` import; no need to fork. 
+ +### What we replace with MX-shape + +| PI component | Our replacement | +|--------------|-----------------| +| `NIXLWeightBroadcast` | `MxWeightBroadcast` — same prime-rl ABC (`WeightBroadcast`), different internals | +| `NIXLWeightUpdateWorker` | `MxWeightUpdateWorker` — same vLLM worker_extension_cls slot | +| `TransportPlan` | Replaced. Per-step iteration over `state_dict` tuples instead of slot table | +| `model.conversion_specs(layer_idx)` | Removed entirely. Trainer publishes whatever's in `state_dict`; vLLM's `model.load_weights()` does the HF→kernel format on inference (its tested code path) | +| `Slot{Sharded,Gathered,Expert}` | Removed. The per-rank publishing semantics come from `MxTrainingPublisher`'s `worker_rank` field; tiny-tensor coalescing is a NIXL register optimization not a slot type | +| SPG rendezvous (`StatelessProcessGroup`) | Removed. Discovery is `MxClient.poll_for_source`; per-step barrier is `MxClient.fence` (new) or fall back to a small SPG over MX-discovered endpoints | +| `ConversionSpec`/`QuantizationSpec` (FP8 quantize on trainer) | Optional on trainer side. Adopted as `MxQuantizer` if user wants FP8 fast path; default is BF16 passthrough. Inference-side decompresses via vLLM's existing FP8 loader, not via NIXL post-processing. | + +### What we add that PI doesn't have + +These are MX traditional features that translate naturally into prime-rl now that we're not constrained by PI's slot abstraction: + +- **Pipeline replication** (TensorHub-style DAG): rollout publishes itself as a secondary source after receive; MX Server load-balances new pollers across trainer + rollouts. +- **Peer recovery**: a restarting rollout pod pulls from any surviving peer (via `poll_for_sources` ranked by health/locality), not always from trainer. +- **Versioning + retention**: keep-latest-N versions on MX Server; rollouts can request a specific version or "latest stable." 
+- **Mutability contract**: explicit `unpublish(version)` before trainer reuses slot buffers; server blocks until in-flight pulls drain. Directly addresses PI's KL drift hypothesis (write ordering / live param visibility). +- **Cross-framework**: same `MxTrainingPublisher`/`MxRefitReceiver` already proven in verl and our existing PRIME-RL POC. Path B is the alignment: prime-rl + verl + future NeMo-RL share the same MX abstractions. +- **Scratch-buffer default**: receive lands in isolated GPU tensors; vLLM `model.load_weights()` applies them via its tested NCCL-equivalent path. KL-drift class of bugs falls away. Direct refit becomes opt-in for users who measure and accept the correctness risk. +- **Elastic shapes**: per-step `(name, shape, dtype)` publishing means LoRA-RL, dynamic adapters, growing context lengths all work without re-init. + +--- + +## 4. Concrete File Footprint + +What changes in `KavinKrishnan/prime-rl:kavink/mx-on-nixl` to ship Path B (relative to current Path A overlay): + +### New files + +| File | Purpose | Est. LOC | +|------|---------|----------| +| `src/prime_rl/trainer/rl/broadcast/mx.py` | `MxWeightBroadcast` — new prime-rl `WeightBroadcast` impl. ~3 methods: `__init__` (publisher + agent setup), `broadcast_weights(model, step)` (publish state_dict tuples, signal orchestrator), `shutdown()` | ~300 | +| `src/prime_rl/inference/vllm/worker/mx.py` | `MxWeightUpdateWorker` — vLLM worker_extension_cls. Delegates to `MxRefitReceiver` for receive, then `model.load_weights(scratch_iter)` for apply. Direct-refit opt-in via env. | ~250 | +| `src/prime_rl/configs/...` | New `MxWeightBroadcastConfig` discriminator on `WeightBroadcastConfig` union. `type: "mx"` | ~60 | +| `docs/weight-transfer-modelexpress.md` | Already requested by `@mikasenghaas` in #2326 review. Documents all four backends (filesystem, nccl, nixl, mx) with selection guidance | ~250 | + +### Modified files + +| File | Change | Est. 
LOC | +|------|--------|----------| +| `src/prime_rl/configs/trainer.py`, `orchestrator.py`, `rl.py` | Add `MxWeightBroadcastConfig` to discriminated union; thread through unify-mode for the `rl` entrypoint | +40 | +| `src/prime_rl/trainer/rl/broadcast/__init__.py` | Add `mx` dispatch in `setup_weight_broadcast()` | +15 | +| `src/prime_rl/inference/vllm/server.py` (or where worker_extension_cls is wired) | `mx` value selects `MxWeightUpdateWorker` | +10 | +| `src/prime_rl/utils/client.py` (TRANSFER_READY marker) | Touch for `mx` backend too (already protocol-agnostic per the existing review feedback) | +5 | + +### Files we don't touch + +PI's NIXL backend stays in place. Users who want PI's path get `type: "nixl"`; users who want ours get `type: "mx"`. No conflict between the two backends — they coexist as discriminator options. + +### What MX-side code we add (in `ai-dynamo/modelexpress`) + +Mostly already exists from the verl + existing PRIME-RL POCs. Net new: + +- `MxTrainingPublisher.publish_weights_via_nixl(state_dict, step, agent)` — adapt our existing publisher to use PI's `NixlAgentWrapper` directly (same NIXL bytes as PI, just driven from our publisher class). +- `MxRefitReceiver.receive_weights_via_nixl(source, agent, target)` — same. +- `MxClient.fence(model, version, world_size)` — server-mediated barrier (replaces SPG barrier per step). Optional; can fall back to SPG over MX-discovered endpoints if we want to minimize server changes. + +Plus the server-side capabilities tracked in `PRIMERL_MX_OVERVIEW.md` §3 (pipeline replication index, retention, peer recovery preference ordering). + +--- + +## 5. Migration Path for Users + +Users currently on PI's `type: "nixl"`: + +```toml +# Before (PI's path) +[weight_broadcast] +type = "nixl" +host = "..." 
+port = 29502 +inference_world_size = N +backends = ["UCX"] + +# After (MX-native path), no rebuild required +[weight_broadcast] +type = "mx" +mx_server_url = "modelexpress-server.kavin.svc.cluster.local:8001" +model_name = "..." +inference_world_size = N +# transfer_mode = "scratch" # default; opt-in "direct" for the speed/correctness tradeoff +# pipeline_replication = false # default; opt-in for fan-out scale +``` + +Same image, same UCX setup, same NIXL transport — just a different config discriminator and ~600 LOC of new code in prime-rl. PI's existing `nixl` backend stays available. + +For ourselves: we delete the monkey-patch on Qwen3 (`qwen3_specs_patch.py`) — it's no longer needed because Path B doesn't require per-model conversion specs. Path A's overlay survives as-is for users who want to stick with PI's transport API exactly. + +--- + +## 6. Pitch Sequence for the Design Conversation + +If we propose Path B to PI on the existing draft PR (or in a new sibling PR): + +1. **Acknowledge PI's transport win directly.** "Your NIXL transport is excellent — UCX setup, classic_cuda_pool, pin_ucx_rail, FP8 ConversionSpec are all correct. Path B keeps every byte of that." + +2. **Frame the divergence as scope.** "We've been running ModelExpress as a metadata + elasticity layer across verl and our internal PRIME-RL POC for several months. The MX-shape API is model-agnostic, server-mediated, scratch-buffer-default — it solves problems your `Slot`/`ConversionSpec` design doesn't address (cross-framework, elastic shapes, KL drift via scratch path, retention, pipeline replication)." + +3. **Show the demo evidence.** "Path A overlay (already up as #2343) proved the metadata layer works on top of your transport. Path B extends that with a native MX API — here's the design doc, here's a draft diff." + +4. **Make the cohabitation explicit.** "Path B doesn't replace `type: nixl`. It adds `type: mx` as a sibling discriminator. Users pick. 
PI's GLM-5 production keeps using `nixl`; users who want a model-agnostic / cross-framework / scratch-default path use `mx`. Same UCX runtime image, same NIXL bytes, same `pin_ucx_rail` discipline — different control plane." + +5. **Address the KL drift directly.** "The drift you're chasing in iter22-27 is consistent with concurrent live-param writes. The MX scratch-buffer path isn't subject to that bug surface because writes land in isolated tensors, then `model.load_weights()` applies them via vLLM's NCCL-equivalent code path. Happy to A/B this on your iter26/27 config — same NIXL bytes, just a different target buffer. If your drift disappears in scratch mode, that's diagnostic data even if you keep `nixl` as default." + +--- + +## 7. Inflection Points to Pivot A → B + +- **PI's PR stalls on KL drift > 2 weeks** without a root-cause fix landing. Path B's scratch-buffer default ships independent of that investigation. +- **Reviewer pushback on Path A's overlay shape** — env-var-as-config or monkey-patching Qwen3 specs gets pushback that suggests a cleaner design is wanted. +- **Need for elastic / LoRA-RL features** that PI's slot system can't accommodate. Path B's per-step publish handles dynamic shapes naturally. +- **Cross-framework alignment becomes a strategic priority** — leadership wants "one weight broadcast story across prime-rl + verl + future frameworks" rather than per-framework redesigns. +- **Scenario A's first NIXL run reveals the same KL drift on Qwen3** as PI saw on GLM-5. That's a strong empirical signal that direct-refit-into-live-params is the bug surface, not GLM-specific. Path B's scratch-default becomes the obvious fix. + +--- + +## 8. Decisions Pending + +| Decision | Default if not made | +|----------|---------------------| +| Ship Path B alongside Path A or as a follow-up PR? | Follow-up PR; Path A goes first | +| MX-mediated barrier (`MxClient.fence`) or SPG-over-MX-discovered endpoints? 
| SPG-over-MX-discovered for v0.1 (smaller server change) | +| `MxWeightBroadcast` lives in `prime_rl/trainer/rl/broadcast/mx.py` (in-tree) or as a plugin from `modelexpress` package? | In-tree mirroring PI's nixl.py pattern | +| Scratch vs direct refit default | Scratch (correctness-safe). Direct opt-in via `transfer_mode: direct` | +| Pipeline replication default | Off. Opt-in via `pipeline_replication: true` | + +--- + +## 9. Status + +- Path A draft PR: [PrimeIntellect-ai/prime-rl#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343) — Qwen3 conversion_spec patch in flight, image rebuild done, deploy pending tsh refresh. +- Path B: design only (this doc). No code yet. Ready to author when an inflection point above triggers. +- Tracking: `recovery/reinforcement learning/PRIME_INTELLECT_PR2326_Analysis.md` for the strategic comparison; `PRIMERL_MX_OVERVIEW.md` for the Path A overlay design. + +This doc is the artifact we'd reference in the conversation with PI if Path A doesn't carry the day on its own. diff --git a/docs/RL/PRIMERL_MX_OVERVIEW.md b/docs/RL/PRIMERL_MX_OVERVIEW.md new file mode 100644 index 00000000..c9d6cacd --- /dev/null +++ b/docs/RL/PRIMERL_MX_OVERVIEW.md @@ -0,0 +1,839 @@ +# ModelExpress × PRIME-RL — Design Overview + +**Last Updated**: April 24, 2026 +**Status**: **Scenarios A, B, and C all green on GB200** — overlay validated end-to-end against PI's `nixl-weight-transfer` branch ([#2326](https://github.com/PrimeIntellect-ai/prime-rl/pull/2326)). All three scenarios completed 20/20 RL training steps with real NIXL RDMA pushes. Measured per-rank wire bandwidth **7.82–8.84 GB/s** (596 MB / ~80 ms per push, 310 slots), aggregate net BW 35–39 GB/s — exceeds PI's reported ~7.5 GB/s wire target. Scenario C also produced the pipeline-replication catalog entry (`rollout-source-0-*`), confirming the MX-side protocol works end-to-end. 
Scenarios D (scratch-buffer diagnostic) and E (peer recovery) are designed but deferred — they're follow-up axes (KL-drift triangulation, fault tolerance) not on the critical path for the overlay PR. Draft PR: [#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343). + +This document covers how ModelExpress (MX) plugs into [PRIME-RL](https://github.com/PrimeIntellect-ai/prime-rl)'s NIXL weight-transfer path as a **metadata and elasticity layer on top of** the existing `NIXLWeightBroadcast` / `TransportPlan` introduced by PR #2326. We do not reimplement their transport. We replace the SPG (StatelessProcessGroup) rendezvous with an MX-Server-mediated discovery plane, add pipeline replication, add a mutability contract, and enable a scratch-buffer diagnostic mode — all opt-in behind a single config flag. + +--- + +## 1. Design Overview + +### What MX adds to PRIME-RL's NIXL backend + +PR #2326 gives PRIME-RL a bit-exact RDMA weight transport built on NIXL/UCX over RoCE, with slots (`ShardedSlot` / `GatheredSlot` / `ExpertSlot`), model-agnostic `ConversionSpec` / `QuantizationSpec`, FP8 trainer-side quantization, HSDP primary-replica push, per-rank NIC pinning, and an `expandable_segments`-safe `CUDAPluggableAllocator` slot pool. The transport works. What it doesn't have is a dynamic discovery plane. 
+ +| Layer | Role in PR #2326 | Role with MX overlay | +|-------|------------------|----------------------| +| Data plane | NIXL RDMA (UCX / `rc_mlx5` / RoCE) | **Unchanged** — identical bytes on the wire | +| Slot / bucket system | `ShardedSlot`, `GatheredSlot`, `ExpertSlot`, `TransportPlan` | **Unchanged** — imported as-is | +| Publishing topology | **Per-rank sharding-aware** — each trainer rank publishes its own FSDP / TP / EP shard; no rank-0 allgather | **Unchanged** — this is a core property of PI's foundation, inherited by the overlay (see §3.9) | +| Quantization / conversion | `ConversionSpec`, `QuantizationSpec` | **Unchanged** — same trainer-side FP8 path | +| Rendezvous / discovery | SPG — static, rank-paired, global-world-size fixed at init | **Replaced by MX Server** (gRPC + Redis) when `rendezvous: mx_server` is set | +| Topology | Star (trainer rank k → inference rank k, 1:1, no fan-in to rank 0) — trainer NIC is the single source for all fan-out | **Dynamic DAG** — trainer seeds the first rollout; each finished rollout becomes an additional source; MX Server load-balances new pollers across the growing source set. Same TensorHub pattern. 
Trainer NIC stops being a bottleneck once any rollout has received (§3.2) | +| Mutability contract | None | **Explicit `publish` / `unpublish`** — trainer publisher marks slots immutable during rollout pulls, mutable before `optimizer.step()` | +| Elastic topology | No — SPG locks `dp_shard×cp + inference_ws` at boot | **Yes** — rollouts can join / leave mid-run via `poll_for_source` | +| Retention | None — no version history | **Keep-latest-N** — MX Server reaper preserves designated versions, CPU-offloads the last GPU copy if necessary | +| Cross-framework | prime-rl only | **Same MX client** also powers verl `MxCheckpointEngine`, future NeMo-RL | +| Expert-aware source tracking (MoE) | Implicit (ExpertSlot in client; no server-level index) | **Explicit server-side `(model, version, expert_id) → worker` index** + `poll_for_expert_source` RPC. Low-hanging win — primitives already exist in MX Server, overlay wires them up (§3.7) | +| Peer recovery on pod restart | Not available — recovering rank must re-pull from trainer | **Multi-source discovery** — `poll_for_sources` returns ranked live peers holding the current version of rank k's shard; recovering rank pulls from nearest/least-loaded peer. Uses the same source index pipeline replication writes to. No event log / no version replay (§3.10) | +| Scratch-buffer diagnostic | Not supported (direct refit only) | **Opt-in via `transfer_mode: scratch`** — uses PI's same transport but stages into isolated GPU tensors + `model.load_weights()` for KL-drift triangulation | + +### Component diagram (vertical, document-friendly) + +```mermaid +graph TB + subgraph driver["Driver · CPU · orchestrator process"] + orch["RL Orchestrator<br/>(PI, unchanged)"] + httpapi["/pause /resume /update_weights<br/>(vLLM WeightTransferEngine endpoints)"] + orch --> httpapi + end + + subgraph mx_meta["Metadata Plane · CPU · MX overlay adds"] + mx["MX Server<br/>(gRPC)"] + redis[("Redis")] + mx --> redis + end + + subgraph trainer["Trainer node · FSDP2 + optimizer"] + direction TB + tw["Trainer ranks × N<br/>(dp_shard × cp)"] + tp["NIXLWeightBroadcast<br/>+ TransportPlan<br/>(PI, unchanged)"] + pub["MxTrainingPublisher<br/>(MX overlay,<br/>env-var gated)"] + tnixl(["NIXL Agent × N<br/>(PI)"]) + tw --> tp + tp -.->|register_coordinator<br/>once at boot| pub + tp --> tnixl + end + + subgraph rollout["Rollout nodes · vLLM TP"] + direction TB + cew["NIXLWeightUpdateWorker × M<br/>(PI, unchanged)"] + rcv["MxRefitReceiver<br/>(MX overlay,<br/>env-var gated)"] + rnixl(["NIXL Agent × M<br/>(PI)"]) + vllm["vLLM engine × M<br/>(live params)"] + cew --> rcv + cew --> rnixl + cew ==> vllm + end + + pub -. "gRPC publish_spg_coordinator<br/>(boot) · mark_version_ready (per step)" .-> mx + rcv -. "gRPC discover_spg_coordinator<br/>(boot)" .-> mx + rcv -. "gRPC publish_rollout_source<br/>(§3.2 pipeline replication, optional)" .-> mx + + cew <== "SPG all_gather_obj × 2 rounds<br/>(agent_meta, slot_layout, xfer descs)<br/>PI code, unchanged" ==> tp + tnixl <== "NIXL RDMA WRITE<br/>trainer → rollout recv buffer<br/>RoCE · rc_mlx5 (PI, unchanged)" ==> rnixl + + style driver fill:#1a1a2e,stroke:#533483,color:#e0e0e0 + style mx_meta fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0 + style trainer fill:#0f3460,stroke:#533483,color:#e0e0e0 + style rollout fill:#0f3460,stroke:#533483,color:#e0e0e0 + style orch fill:#533483,stroke:#e94560,color:#fff + style httpapi fill:#533483,stroke:#e94560,color:#fff + style tw fill:#533483,stroke:#e94560,color:#fff + style tp fill:#533483,stroke:#e94560,color:#fff + style cew fill:#533483,stroke:#e94560,color:#fff + style vllm fill:#533483,stroke:#e94560,color:#fff + style pub fill:#1b5e20,stroke:#4caf50,color:#fff + style rcv fill:#1b5e20,stroke:#4caf50,color:#fff + style mx fill:#1b5e20,stroke:#4caf50,color:#fff + style tnixl fill:#2e7d32,stroke:#66bb6a,color:#fff + style rnixl fill:#2e7d32,stroke:#66bb6a,color:#fff + style redis fill:#162447,stroke:#533483,color:#e0e0e0 +``` + +**Legend**: + +- **Purple boxes + solid purple edges** — existing PRIME-RL / vLLM / PI #2326 code the overlay imports and uses as-is. +- **Green boxes + dotted green edges** — MX additions: MX Server, overlay client classes, gRPC control-plane calls. +- **Data plane** (trainer NIXL ↔ rollout NIXL, solid double-edge) is **100% PI** — MX does not see or touch weight bytes. +- **SPG metadata rounds** (trainer ↔ rollout double-edge between `TransportPlan` and `NIXLWeightUpdateWorker`) are **100% PI** — MX only swaps how participants *find* the SPG coordinator; the two `all_gather_obj` rounds themselves are untouched. + +### Key ideas + +- **MX Server stores metadata only.** Slot layouts, tensor descriptors, NIXL agent blobs, version numbers. It never touches weight bytes. +- **The data path is PI's, unchanged.** NIXLWeightBroadcast + TransportPlan + Slot classes are imported and used as-is. Our value-add is what happens *before* (discovery) and *alongside* (lifecycle, pipeline replication, diagnostics). 
+- **Opt-in via one config field.** `weight_broadcast.rendezvous: "spg" | "mx_server"` (default `"spg"`). Flip the flag, no code paths diverge. +- **Pipeline replication is a client-side change.** After a rollout receives weights, it optionally re-registers itself as a source. The next rollout to poll discovers *either* the trainer or a replicated rollout, closer / less-loaded wins. Amplifies trainer NIC bandwidth in fan-out-heavy topologies. +- **Scratch-buffer diagnostic mode** reuses PI's transport but lands writes in isolated GPU tensors, then applies via `model.load_weights()`. Used for triangulating correctness issues like the KL drift in #2326. + +--- + +## 2. Timing Diagram — One `update_weights` Step + +Shows the MX-mediated path (`rendezvous: mx_server`). The SPG path is unchanged from PI #2326. The overlay's v0.1 scope is **coordinator discovery only** — once the SPG coordinator is found via MX, PI's existing 2-round `all_gather_obj` metadata exchange and `TransportPlan` RDMA WRITE run bit-identically. + +```mermaid +sequenceDiagram + participant O as Orchestrator + participant T as Trainer ranks<br/>(NIXLWeightBroadcast) + participant PUB as MxTrainingPublisher + participant MX as MX Server + participant RCV as MxRefitReceiver + participant R as NIXLWeightUpdateWorker<br/>(rollout ranks) + participant V as vLLM engine + + rect rgba(76,175,80,0.08) + Note over T,MX: Boot-time (once per run) — late-bound SPG discovery + T->>PUB: register_coordinator(model, host, port) + PUB->>MX: gRPC publish_spg_coordinator(model, host:port) + R->>RCV: init(model_name, worker_rank=k) + RCV->>MX: gRPC discover_spg_coordinator(model) + MX-->>RCV: SPG host:port + end + + Note over T: optimizer.step() complete + + O->>R: POST /pause + R-->>O: 200 OK (quiesced) + + rect rgba(83,52,131,0.10) + Note over T,R: SPG 2-round metadata exchange (PI code, unchanged) + par per step + T->>T: StatelessProcessGroup(host, port, rank, world_size) + R->>R: StatelessProcessGroup(host, port, rank, world_size) + end + T-->>R: Round 1 — NIXL agent_meta, slot_layout,<br/>recv-buffer descriptors + T-->>R: Round 2 — per-slot xfer descriptors + T->>T: dist.barrier() (PI iter15 pre-write quiescence) + end + + rect rgba(83,52,131,0.10) + Note over T,R: Data plane — PI RDMA WRITE (unchanged) + loop per slot bucket (TransportPlan.drain) + T-->>R: NIXL RDMA WRITE → rollout's recv buffer<br/>(RoCE, rc_mlx5) + end + end + + par finalize + T->>PUB: mark_version_ready(N) + PUB->>MX: gRPC mark_version_ready(model, version=N) + R->>RCV: finalize() + RCV->>V: direct refit into live params<br/>or scratch → model.load_weights() + end + + opt pipeline_replication=true + RCV->>MX: gRPC publish_rollout_source(model, version=N, agent_meta) + Note over MX: future pollers may be<br/>steered to this rollout<br/>(see §3.2 DAG) + end + + opt async_level ≥ 1 — trainer about to mutate slots + T->>PUB: unpublish(version=N) + PUB->>MX: gRPC unpublish(model, version=N) + MX->>MX: wait for in-flight pulls<br/>(version N) to drain + MX-->>PUB: OK (buffers safe to mutate) + end + + O->>R: POST /resume + R-->>O: 200 OK +``` + +**Legend**: Green-tinted block = one-time boot path (MX discovery — the only control-plane change the v0.1 overlay makes). Purple-tinted blocks = per-step flow that's **unchanged from PI #2326** (SPG metadata rounds, barrier, RDMA WRITE). MX Server hooks at finalize + optional blocks (`publish_rollout_source`, `unpublish`) are additive — absent when the corresponding feature flag is unset. + +### Observed per-step timing + +All three scenarios measured on GB200 (Qwen3-0.6B BF16, 2 trainer ranks × 1 inference rank, customer-gpu-o7v pool, 20 RL training steps each). + +| Phase | Scenario A — PI SPG baseline | Scenario B — MX rendezvous | Scenario C — MX + pipeline replication | +|-------|------------------------------|----------------------------|------------------------------------------| +| Rendezvous (SPG init vs MX gRPC) | ~0.8s first call, none per-step (single session for the whole run) | **~1s gRPC publish + discover** (boot-time, single call); per-step = 0 | **~1s gRPC publish + discover + 1 publish_as_rollout_source**; per-step = 0 | +| Per-push RDMA WRITE | 596 MB bucket / 310 slots; avg total = 85 ms (convert 67 ms + post+wait 16 ms + barrier 1 ms) | same as A (transport untouched) | **measured: convert 60 ms + post+wait 15 ms + barrier 1.5 ms = ~77 ms total** | +| Per-rank wire BW | not separately captured in A | not separately captured in B | **7.82–8.84 GB/s** (avg ~8.1 GB/s, exceeds PI's reported ~7.5 GB/s target) | +| Aggregate net BW | not separately captured | not separately captured | **35–39 GB/s** (rank 0 + rank 1 combined NIC throughput) | +| Avg trainer step time (steps 7-19) | ~5.1s | ~4.22s | ~4.7s | +| Trainer steps completed | **20 / 20** | **20 / 20** | **20 / 20** | +| `/update_weights` 200 OK rate | 100% | 100% | 100% | + +Parity with PI on the data path is the acceptance criterion for the MX overlay.
Achieved: A and B use identical NIXL pushes (B's MX rendezvous swap doesn't touch the data path), C demonstrates the same wire-rate transfer plus the catalog adding a `rollout-source-0` entry that future pollers could use. + +**Known caveat for comparison against PI's reported 12-node prod numbers**: our runs use SDPA (our ARM64 image ships a flash-attn import stub; real kernels require a ~3h QEMU compile). This caps trainer compute throughput at ~6.7k tokens/s/rank vs PI's prod numbers which assume flash-attn. The **NIXL transfer** portion is unaffected (same UCX rc_mlx5, same bytes on wire); only the **step-time** fraction attributable to training (forward + backward) is inflated. Flash-attn parity is a P1 follow-up in `PRIMERL_POC_Next_Steps.md`. + +**Pipeline-replication catalog state** (scenario C, after init): MX Server's `list_sources` for the run identity returned 4 entries: + +```text +worker_rank=0 worker_id=primerl-overlay-scenario-c-trainer-0-997daac3 # SPG coordinator +worker_rank=1 worker_id=primerl-overlay-scenario-c-trainer-1-354aaebe # rank-1 self-publish +worker_rank=0 worker_id=primerl-overlay-scenario-c-inference-0-8db04e3d # standard rollout +worker_rank=0 worker_id=primerl-overlay-scenario-c-rollout-source-0-faaaf5e5 # ← pipeline replication +``` + +The fourth entry — written by `publish_as_rollout_source()` after the inference rollout finished receiving weights — is the empirical proof that the pipeline-replication mechanism works on the MX side. Subsequent rollouts polling for this `(model, version)` would discover *both* the trainer and this rollout as candidate sources. Bandwidth amplification per the §3.2 DAG model is gated on extending PI's SPG to support dynamic world_size (so a late-joining rollout can actually attach mid-run); that extension is in the §3 follow-up list. + +--- + +## 3. ModelExpress Value Layer + +This section documents what the overlay changes relative to PI's PR #2326. 
+ +### 3.1 SPG → MX Server rendezvous + +**What SPG provides in #2326**: A fixed-world-size group over TCP, used to exchange NIXL agent metadata at init. Every participant must be present at the same time; adding/removing a rollout requires a full process restart. + +**What MX Server provides**: +- Each trainer rank calls `MxTrainingPublisher.publish(agent_meta, slot_layout, version)` once per step (gRPC). +- Each rollout calls `MxRefitReceiver.poll_for_source(model_name, worker_rank, min_version)` — returns the matching trainer rank's agent metadata and slot layout. +- Poll is idempotent and cache-friendly; rollouts can join mid-run, leave, or be restarted without affecting other participants. + +**Config surface**: + +```yaml +weight_broadcast: + type: nixl # use PI's transport + rendezvous: mx_server # instead of spg + mx_server_url: modelexpress-server.kavin.svc.cluster.local:8001 + model_name: "zai-org/GLM-4.5-Air-FP8" # example; see §4 for final selection +``` + +When `rendezvous: spg` (the default), behavior is 100% identical to PI's PR #2326. + +### 3.2 Pipeline replication — dynamic DAG of rollouts-as-sources + +Rollouts form a **dynamic DAG**, not a static star. Every `publish_rollout_source(version=N)` call adds a new parent edge available to unfinished rollouts. Every `poll_for_source(version=N)` gets load-balanced across the currently-available parent set (trainer + any rollouts that have already finalized version N). The DAG is built organically as receives complete; there is no precomputed topology. + +This is the same architectural pattern as TensorHub's Reference-Oriented Storage (ByteDance, April 2026; see `recovery/reinforcement learning/TensorHub_Analysis.md` for the design comparison). + +```yaml +weight_broadcast: + rendezvous: mx_server + pipeline_replication: true # default false +``` + +**DAG buildup over time** (12 rollouts, single trainer source for a given rank k). 
Each phase shows which workers act as sources and which are still polling. Edges represent "may be selected as a source by future pollers." + +```mermaid +flowchart TB + subgraph t0["t = 0 — only the trainer is a source"] + T0[Trainer] + P0(["R0..R11
(polling)"]) + T0 --> P0 + end + + subgraph t1["t = t1 — R0 finished first, now also a source"] + T1[Trainer] + R0_1[R0] + P1(["R1..R11
(polling)"]) + T1 --> P1 + T1 --> R0_1 + R0_1 --> P1 + end + + subgraph t2["t = t2 — R1 and R2 finalize from {Trainer, R0}"] + T2[Trainer] + R0_2[R0] + R1_2[R1] + R2_2[R2] + P2(["R3..R11
(polling)"]) + T2 --> P2 + R0_2 --> P2 + R1_2 --> P2 + R2_2 --> P2 + end + + subgraph t3["t = t3 — Trainer + R0..R6 serve R7..R11"] + T3[Trainer] + R06[R0..R6] + P3(["R7..R11
(polling)"]) + T3 --> P3 + R06 --> P3 + end + + subgraph t4["t = t4 — all 12 rollouts hold version N"] + Done["{Trainer, R0..R11}"] + end + + t0 --> t1 --> t2 --> t3 --> t4 + + style T0 fill:#533483,stroke:#e94560,color:#fff + style T1 fill:#533483,stroke:#e94560,color:#fff + style T2 fill:#533483,stroke:#e94560,color:#fff + style T3 fill:#533483,stroke:#e94560,color:#fff + style R0_1 fill:#1b5e20,stroke:#4caf50,color:#fff + style R0_2 fill:#1b5e20,stroke:#4caf50,color:#fff + style R1_2 fill:#1b5e20,stroke:#4caf50,color:#fff + style R2_2 fill:#1b5e20,stroke:#4caf50,color:#fff + style R06 fill:#1b5e20,stroke:#4caf50,color:#fff + style Done fill:#1b5e20,stroke:#4caf50,color:#fff +``` + +**Per-phase outbound bandwidth** (assuming each NIC = 1 unit of outbound): + +| Phase | Sources serving pollers | Aggregate outbound | Pollers remaining | +|---|---|---|---| +| `t=0` | `{Trainer}` | 1× | 12 | +| `t=t1` | `{Trainer, R0}` | 2× | 11 | +| `t=t2` | `{Trainer, R0..R2}` | 4× | 9 | +| `t=t3` | `{Trainer, R0..R6}` | 8× | 5 | +| `t=t4` | (all done) | — | 0 | + +**Bandwidth math**: A naive star with T trainer NICs serving R rollouts caps aggregate throughput at T × per-NIC-BW, regardless of R. The DAG caps aggregate throughput at R × per-NIC-BW (every GPU's outbound contributes once it has received). For R=12 and T=8 on the PI prod shape, this is a 1.5× headroom; for R=64 on a future scale-out, it's 8× headroom. + +**Load-balancing preference** (pipeline replication mode, distinct from §3.10 peer recovery): + +- Spread load evenly across currently-available sources (round-robin within a locality tier). +- Prefer same-rack sources over cross-rack to minimize inter-switch hops. +- Avoid overloading the trainer — once any rollout is available, weight the trainer lower in selection so its NIC stays free to seed new pushes. 
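
The three load-balancing preferences above can be sketched as a small selection function. This is only an illustration under stated assumptions: `Source`, `pick_source`, `TRAINER_PENALTY`, and the `load` counter are hypothetical names, not part of MX Server or PI's code, and "round-robin within a locality tier" is approximated here by picking the least-loaded source in the best tier.

```python
from dataclasses import dataclass

# Hypothetical penalty added to the trainer's load once any rollout source
# exists, so the trainer is selected less often (value chosen for illustration).
TRAINER_PENALTY = 2


@dataclass(frozen=True)
class Source:
    worker_id: str
    rack: str
    is_trainer: bool = False


def pick_source(sources, poller_rack, load):
    """Pick one source for a poller and record the assignment in `load`.

    `load` maps worker_id -> number of pulls already assigned (mutated here).
    """
    if not sources:
        raise ValueError("no sources available")
    have_rollout = any(not s.is_trainer for s in sources)

    def sort_key(s):
        cross_rack = s.rack != poller_rack           # same-rack tier wins
        penalty = TRAINER_PENALTY if (s.is_trainer and have_rollout) else 0
        effective_load = load.get(s.worker_id, 0) + penalty
        return (cross_rack, effective_load, s.worker_id)

    chosen = min(sources, key=sort_key)
    load[chosen.worker_id] = load.get(chosen.worker_id, 0) + 1
    return chosen
```

With a trainer and one rollout on the poller's rack, the rollout absorbs the first few pulls and the trainer is only chosen once the rollout's assigned load exceeds the penalty. A real server would presumably decrement `load` when a pull completes; that bookkeeping is omitted here.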
+ +Contrast with §3.10 peer-recovery preference, which prefers same-node > same-rack > any > trainer-last (optimizing for recovery *latency* vs pipeline *throughput*). + +**Server-side state used** (shared with peer recovery in §3.10 — same index, two entry points): + +```text +sources_index : Map<(model, version, worker_rank), Set> +source_health : Map +source_load : Map // for load-balancing +``` + +**Current limit (TensorHub has, we don't yet)**: We call `publish_rollout_source()` after `finalize()` — i.e., once *all* slots have been received. TensorHub goes further and lets a partially-replicated rollout serve its *completed* slots while still receiving others. This deepens the pipelining further (rollout A's slot 0 can feed rollout B's slot 0 even while rollout A is still pulling slot 1 from the trainer). We've chosen post-finalize-only for the initial overlay to keep the correctness surface small; partial-replica serving is in the future-work list with a clear enablement path: + +1. Publisher exposes `publish_partial_source(version, completed_slots[])` that accepts a slot-id bitmap. +2. Server-side index keys on `(model, version, worker_rank, slot_id)` instead of `(model, version, worker_rank)`. +3. Receiver filters candidate sources per slot, composes multi-source pulls. + +Low-risk to add post-merge; no impact on the initial PR's success criteria. + +### 3.3 Mutability contract (publish / unpublish) + +PI's protocol currently relies on a `dist.barrier()` after `NIXL_READY` to ensure trainer ranks don't start writing while inference is still reading (iter15 fix). This works for synchronous push, but at async level ≥ 1 (pre-fetch next rollouts while current training step runs) the trainer will re-use slot buffers for the next `optimizer.step()` while rollouts may still be pulling the previous version. + +The MX overlay adds: +- `unpublish(version=N)` gRPC — trainer calls this just before buffer mutation. 
+- MX Server blocks `unpublish` until all in-flight pulls for version N have completed (tracked via heartbeat ACKs from rollouts). +- Publisher then signals the trainer that slot buffers are safe to mutate. + +### 3.4 Retention protocol + +MX Server enforces keep-latest-N versions per model (default N=2). If the last GPU-resident copy of a retained version is about to be unpublished, the server offloads that version to CPU memory as a fallback. Rollouts pulling a version that exists only on CPU pay a higher latency but don't fail. This prevents version loss under elastic churn and matches TensorHub retention semantics. + +### 3.5 Scratch-buffer diagnostic mode + +PI's direct-refit writes RDMA-received bytes directly into live vLLM parameter memory. The KL drift investigation in #2326 (27+ iterations, unresolved) narrowed the bug to "NIXL write mechanism itself": write-ordering / visibility / tensor-identity hazards at the intersection of RDMA and live CUDA tensors. + +The MX overlay offers an alternate target: + +```yaml +weight_broadcast: + rendezvous: mx_server + transfer_mode: scratch # default: direct +``` + +In `scratch` mode: +- RDMA writes land in isolated scratch GPU tensors allocated by the receiver, not in live vLLM params. +- After the transfer completes, `model.load_weights()` applies them, the same code path NCCL uses. +- Bit-exact byte check passes in both modes (transport is identical); any KL divergence between `direct` and `scratch` isolates the bug to the direct-refit target layout (kernel format, stride, identity), *not* the NIXL mechanism. + +Memory cost is scratch-sized (~3.5 GB for 1.5B, ~15 GB for 7B, ~30 GB for 32B-class MoE). Intended for diagnostic runs, not production. + +### 3.6 Cross-framework reuse + +The same `MxTrainingPublisher` / `MxRefitReceiver` classes underpin `MxCheckpointEngine` in [verl](https://github.com/volcengine/verl). See `docs/RL/VERL_MX_OVERVIEW.md` for that integration's topology. 
Future NeMo-RL integration uses the same client. No framework-specific server code exists in MX Server. + +### 3.7 Expert-aware source tracking (MoE) + +For MoE models at expert parallelism EP > 1, each inference worker holds only a *subset* of experts. Broadcasting every expert to every worker (NCCL, filesystem) wastes `(EP - 1)/EP` of the bandwidth. MX's per-worker publishing model makes expert-selective transfer natural — every trainer EP rank publishes only its local experts; every inference EP rank pulls only its assigned experts from the matching source. + +**Server-side state** (primitives already defined, logic added in this overlay): + +- `WorkerMetadata.worker_rank` identifies the publishing EP rank. +- `SourceIdentity.expert_parallel_size` declares the source's EP degree. +- `TensorDescriptor.name` carries the expert index (`...experts.{N}.gate_proj.weight`). +- New server index `(model, version, expert_id) → worker`, built incrementally as publishers announce slots. + +**New RPCs** in MX Server (added alongside the overlay): + +| RPC | Purpose | +|-----|---------| +| `publish_expert_ownership(source_id, expert_ids[])` | Trainer EP rank announces its assigned experts for the current version | +| `poll_for_expert_source(model, version, expert_id)` | Inference EP rank asks "who holds expert N?"; server returns matching source's agent meta | +| `list_experts(model, version)` | Diagnostic — returns the `expert_id → worker` map | + +**Client-side** — once we adopt PI's `ExpertSlot` in Phase 1, the data is already produced by the publisher; the overlay just flushes it to MX Server instead of keeping it SPG-local. + +**Why this matters**: + +- **Bandwidth**: With EP=8 and 64 experts (matches both GLM-5 and Qwen2-57B-A14B), each inference worker needs 1/8th of expert bytes. Broadcast-based transports can't exploit this; MX can. 
+- **Future load-balancing**: The server-side index enables hot-expert migration (move frequently-used experts closer to their consumers), elastic expert redistribution, and dynamic expert pruning. +- **Minimal effort**: ~2 days server, ~1 day client, ~1 day tests — the primitives already exist. + +### 3.8 Scratch-buffer evolution path + +The scratch-buffer path in §3.5 is the correct *default* — same `model.load_weights()` code path NCCL uses, low correctness risk, clear A/B isolation. But it is not long-term feasible: at 32B the scratch approaches 35% of GB200 HBM; at 70B it no longer fits alongside a realistic KV cache. + +We document a 5-tier evolution so the design has a clear migration target as each correctness concern gets retired: + +| Tier | Approach | Scratch memory | Conversion cost | Correctness risk | Trigger to advance | +|------|----------|---------------|-----------------|------------------|-------------------| +| 0 | **Current** — scratch + `load_weights()` on receiver | Full model | High (every rollout) | Low | — | +| 1 | **ConversionSpec on trainer** + scratch on receiver (`load_weights` becomes plain copy) | Full model (kernel-format) | Paid once on trainer | Low | Adopt PI's `ConversionSpec`/`QuantizationSpec` (Phase 1) | +| 2 | **Streaming per-tensor** — receive one tensor, apply, free, next | Largest single tensor (~0.5–2 GB) | Low | Medium — breaks RDMA bucket batching; per-sub-bucket registration overhead | Need to run >32B before direct-refit is proven | +| 3 | **Tiled / rotating chunked scratch** — fixed cap (e.g., 2 GB) reused across tensors | Fixed cap (configurable) | Low | Medium — trainer/receiver must coordinate chunk cadence | Same as Tier 2 but prefers bounded memory to minimum-latency | +| 4 | **Direct refit** — RDMA writes land in live `param.data`; zero scratch (PI's approach, same code path as their PR #2326 direct mode) | Zero | Zero | **High** — live-param corruption, RDMA ordering, tensor identity hazards (see PI's 
unresolved KL drift in #2326) | KL drift root cause resolved; tensor layout stability contract proven | +| Tier-4 alt | **CPU offload fallback** — receive into pinned CPU RAM, DMA to GPU per layer | Zero GPU, full CPU RAM | PCIe copy per push | Low but slower | Only when GPU memory is the binding constraint | + +**Progression logic**: + +- Tiers 0 → 1 is pure adoption of PI's trainer-side conversion; no new correctness risks. +- Tiers 1 → {2, 3} is a memory optimization when scratch no longer fits. Both are valid; choose per deployment. +- Tier 4 is the terminal state but *only* after PI's KL drift investigation converges on a root cause. The `transfer_mode: scratch` diagnostic (§3.5) is the tool that gates this transition — if the drift persists in Tier 4 but not Tier 2/3, we've falsified direct-refit for that model family. + +**What the overlay PR implements now**: + +- Tier 0 — `transfer_mode: scratch` default on our overlay path. +- Tier 4 — available via `transfer_mode: direct` (imports PI's direct-refit code path unchanged). +- Tiers 1/2/3 — documented here as the migration target, not yet implemented. See `PRIMERL_POC_Next_Steps.md` Steps 9 + 10 for the tracked work items. + +### 3.9 Per-rank sharding-aware publishing (no rank-0 allgather) + +PI's `TransportPlan` + `ShardedSlot` / `GatheredSlot` / `ExpertSlot` design means every trainer rank publishes its *own local shard* directly to its matching inference rank. No rank ever holds the full unsharded model. This is inherited unchanged by the overlay — we don't reintroduce an allgather anywhere. + +**Contrast with the naive path** (what our pre-pivot MX POC on `kavink/mx-weight-broadcast` does, and what filesystem / NCCL-broadcast backends effectively do): + +```mermaid +flowchart LR + subgraph naive["Before — naive / pre-pivot MX POC"] + direction LR + N0[Rank 0] + N1[Rank 1] + N2[Rank 2] + N3[Rank 3] + NG{{allgather}} + NF["Rank 0 holds
full state_dict
(4× memory spike)"] + NW(["1× NIXL WRITE"]) + NINF[Inference] + N0 --> NG + N1 --> NG + N2 --> NG + N3 --> NG + NG --> NF --> NW --> NINF + end + + subgraph overlay["After — overlay on top of PI"] + direction LR + O0[Rank 0] --> S0[ShardedSlot 0] --> A0(["NIXL agent 0"]) --> I0[Inference rank 0] + O1[Rank 1] --> S1[ShardedSlot 1] --> A1(["NIXL agent 1"]) --> I1[Inference rank 1] + O2[Rank 2] --> S2[ShardedSlot 2] --> A2(["NIXL agent 2"]) --> I2[Inference rank 2] + O3[Rank 3] --> S3[ShardedSlot 3] --> A3(["NIXL agent 3"]) --> I3[Inference rank 3] + end + + style NF fill:#7a1818,stroke:#ff5252,color:#fff + style NG fill:#7a1818,stroke:#ff5252,color:#fff + style NW fill:#7a1818,stroke:#ff5252,color:#fff + style S0 fill:#1b5e20,stroke:#4caf50,color:#fff + style S1 fill:#1b5e20,stroke:#4caf50,color:#fff + style S2 fill:#1b5e20,stroke:#4caf50,color:#fff + style S3 fill:#1b5e20,stroke:#4caf50,color:#fff + style A0 fill:#1b5e20,stroke:#4caf50,color:#fff + style A1 fill:#1b5e20,stroke:#4caf50,color:#fff + style A2 fill:#1b5e20,stroke:#4caf50,color:#fff + style A3 fill:#1b5e20,stroke:#4caf50,color:#fff +``` + +The remaining bookkeeping captured as plain text (the "After" overlay path): + +```text +Rank 0 ── ShardedSlot 0 ── NIXL agent 0 ── RDMA WRITE ──► Inference rank 0 +Rank 1 ── ShardedSlot 1 ── NIXL agent 1 ── RDMA WRITE ──► Inference rank 1 +Rank 2 ── ShardedSlot 2 ── NIXL agent 2 ── RDMA WRITE ──► Inference rank 2 +Rank 3 ── ShardedSlot 3 ── NIXL agent 3 ── RDMA WRITE ──► Inference rank 3 + + Cost: zero memory spike, 4 NICs in parallel, each rank's transfer + is independent, scales linearly with rank count. 
+``` + +**Slot-type guarantees**: + +| Slot | Source shape | Publish pattern | Memory on any single GPU | +|------|-------------|-----------------|--------------------------| +| `ShardedSlot` | FSDP2-sharded (DTensor) | Each rank publishes `param.to_local()` | Only its local shard — same as training | +| `ExpertSlot` | EP-sharded (experts assigned per rank) | Each EP rank publishes its local experts | Only local experts | +| `GatheredSlot` | Small tensors (< 2 MiB threshold) | Rank 0 gathers and publishes as a bundle | Sum of small tensors (few MB) — handle-count optimization, not a correctness requirement | + +**HSDP** (hybrid-sharded data parallel) further restricts: only `dp_replicate == 0` runs the protocol at all. Non-primary replicas do not allocate NIXL slot buffers, do not register with MX Server, do not send on the wire. They `dist.barrier()` at the end to stay in lockstep with the primary's push. + +**Net effect on GB200 2-node shape** (4 trainer ranks × 4 inference ranks): + +- Memory: 0 GB spike on any single rank regardless of model size. +- NIC utilization: 4 outbound streams in parallel on trainer, 4 inbound on rollout. Total bandwidth = sum of per-rank NICs, not capped at one NIC. +- Correctness: per-rank byte-exact transfer (PI iter16 `nixl_diff.py` confirmed across all slot types). + +**Retiring Step 8**: `PRIMERL_POC_Next_Steps.md` Step 8 ("Eliminate rank-0 allgather — per-rank shard publishing") was one of our original P0 roadmap items. It is now absorbed by the pivot: adopting PI's `Slot` + `TransportPlan` gives us this behavior at Phase 1, with no additional MX-side code to write for the *publishing* topology itself. What remains on our side is (a) the MX rendezvous that routes rank-k → rank-k discovery through the server instead of SPG, and (b) the server-side expert-aware index in §3.7. + +### 3.10 Peer recovery and source redundancy + +A rollout pod crashes and restarts. 
Without recovery support it must re-pull its shard from the trainer, consuming trainer NIC bandwidth that would otherwise serve new pushes. MX Server's source index — the same index populated by pipeline replication (§3.2) and per-rank publishing (§3.9) — lets a recovering rank discover live peers that already hold the current version and pull from the closest / least-loaded one. + +**Server-side state** (a small extension of what pipeline replication already requires): + +```text +sources_index : Map<(model, version, worker_rank), Set> +source_health : Map // TTL-driven liveness, e.g. 10 s +``` + +Every `publish()` or `publish_rollout_source()` call inserts into `sources_index`. Every gRPC RPC from a source refreshes `source_health`. A reaper removes sources whose heartbeats have expired. + +**Recovery API** — `poll_for_source` returns a ranked list, not a single source: + +```python +# Before +source: Source = mx_client.poll_for_source(model, version, worker_rank=k) + +# After (additive; old single-source call still works, wraps list[0]) +sources: list[Source] = mx_client.poll_for_sources( + model, version, worker_rank=k, + prefer=["same_node", "same_rack", "rollout_replica", "trainer"], + max_results=4, +) +receiver.receive_weights_with_fallback(sources) # tries [0], on failure falls through [1..] +``` + +**Preference ordering** (default; configurable): + +1. Same-node source (PCIe copy between processes on the same host, no network involved). +2. Same-rack source (cheaper RDMA hop). +3. Any rollout replica (preserves trainer NIC bandwidth). +4. Trainer rank k (last resort — trainer NIC is the scaling bottleneck for new pushes). + +**Why this is cheap in our design**: + +- The `sources_index` is *already* written to by every publish — no new write path. +- The `source_health` heartbeat is *already* needed to implement retention (§3.4) reaping. 
+- The only net-new is (a) returning a list, (b) client-side fallback loop on RDMA connect failure, (c) preference-ordered ranking in `poll_for_sources`. + +**Sharding-aware natural fit**: because publishing is per-rank (§3.9), a recovering rank k pulls *only its own shard* from peers that have rank k's shard — it does not need to know about ranks 0/1/... and does not re-pull the full state dict. Recovery cost scales with the shard size, not the full model. + +**What this explicitly does NOT do**: + +- **No event log / version replay.** Weights are a point-in-time snapshot, not a sequence of operations. A recovering rank always jumps directly to the current (or requested) version in one transfer. If it was down during versions 95-100, it does not apply updates 95 → 96 → 97 → ...; it copies the state at version 100 from whichever peer has it. +- **No weight deltas in transit.** For standard RL (PPO/GRPO), every parameter changes on every `optimizer.step()`, so `weights[N] - weights[N-1]` is dense and compresses poorly. The bandwidth cost equals the full transfer; the memory cost doubles (sender and receiver both need `weights[N-1]`). Not worth the complexity for dense-update RL. See "Future" below for the structured-sparse case. + +**Retention + recovery interaction**: + +- If the retention protocol (§3.4) keeps latest-N versions, recovery can target any retained version. Typical use: default N=2 lets a newly-booted rollout catch up to "most recent ready" even if the trainer has already advanced one step. +- If the only live source for a retained version is about to heartbeat-out, the retention path triggers CPU offload before eviction, so the version survives until a fresh source publishes. 
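
The interaction between TTL-driven liveness, keep-latest-N retention, and the CPU-offload fallback can be sketched with a small in-memory index. Everything here is a hypothetical stand-in (`SourceIndex`, `reap`, `lookup`, and the `cpu_offloaded` set are illustration names, not MX Server's actual data structures), and the "offload" is only recorded as a flag rather than performed.

```python
class SourceIndex:
    """Hypothetical in-memory model of sources_index + source_health."""

    def __init__(self, ttl_s=10.0, keep_latest=2):
        self.ttl_s = ttl_s
        self.keep_latest = keep_latest
        self.sources = {}           # (model, version, rank) -> set of source ids
        self.last_seen = {}         # source id -> time of last RPC / heartbeat
        self.cpu_offloaded = set()  # keys whose only remaining copy is on CPU

    def publish(self, model, version, rank, source_id, now):
        key = (model, version, rank)
        self.sources.setdefault(key, set()).add(source_id)
        # Assumption: a fresh GPU source supersedes the CPU fallback.
        self.cpu_offloaded.discard(key)
        self.last_seen[source_id] = now

    def heartbeat(self, source_id, now):
        self.last_seen[source_id] = now

    def retained_versions(self, model):
        versions = sorted({v for (m, v, _r) in self.sources if m == model})
        return set(versions[-self.keep_latest:])

    def reap(self, now):
        """Drop expired sources; flag retained versions that lose their last one."""
        models = {m for (m, _v, _r) in self.sources}
        retained = {m: self.retained_versions(m) for m in models}
        for key in list(self.sources):
            model, version, _rank = key
            live = {
                s for s in self.sources[key]
                if now - self.last_seen.get(s, float("-inf")) <= self.ttl_s
            }
            if live:
                self.sources[key] = live
            elif version in retained[model]:
                self.sources[key] = set()
                self.cpu_offloaded.add(key)  # stand-in for the real offload
            else:
                del self.sources[key]

    def lookup(self, model, version, rank):
        """Return (live GPU source ids, whether a CPU fallback exists)."""
        key = (model, version, rank)
        return sorted(self.sources.get(key, ())), key in self.cpu_offloaded
```

In this sketch a version outside the retention window simply disappears once its last source expires, while a retained version stays discoverable through the CPU flag until a new source publishes it again.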
+ +**Future: weight deltas for structured-sparse RL** + +For LoRA-RL, adapter-only RL, or reward-frozen-policy-adapts where only a small submodule updates each step, an actual delta-transfer protocol becomes interesting: + +- Trainer publishes `delta_slot = weights[N].adapter - weights[N-1].adapter` (tiny compared to full weights). +- Receivers apply `param.data += delta` in place. +- MX Server tracks delta lineage per version. +- Peer recovery in this mode asks "give me all deltas from version 95 to 100" — which becomes a log-replay style recovery. + +This is a natural extension of the source index but requires: +- Per-slot base-version tracking. +- Delta application protocol on receiver. +- Fallback to full transfer if a delta is missing (e.g., all sources pruned by retention). + +Not implemented in the current overlay — surfaced here as a documented future direction in the §3.8 evolution path. + +--- + +## 4. Prototype on GB200 — Results + +### 4.1 Cluster + +| Resource | Value | +|----------|-------| +| Platform | GKE DGXCloud, GB200 ARM64 | +| Nodes | 2 (trainer + rollout), `hostNetwork=true` | +| GPUs | 8 × GB200 (4 per node) | +| Node pools | `customer-gpu-w0e` (trainer), `customer-gpu-o7v` (rollout) | +| Fabric | RoCE v2, IMEX via DRA `kavin-compute-domain-channel` | +| UCX | v1.20.0 built from source, `self,sm,rc,cuda_copy,gdr_copy,tcp` | +| NIXL | v1.1.0 (main) | +| PyTorch | 2.6 + cu128 | +| vLLM | 0.18.1 (0.19.0 has ARM64 `resource_tracker` multiproc bug) | +| MX Server | `modelexpress-server.kavin.svc.cluster.local:8001` | +| Image | `nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:latest` (from overlay branch) | +| Base PR | [PrimeIntellect-ai/prime-rl#2326](https://github.com/PrimeIntellect-ai/prime-rl/pull/2326) @ commit TBD | + +### 4.2 Model (tiered) + +PI's PR #2326 targets an internal GLM-5 MoE FP8 model that isn't publicly available. 
We exercise the same PR #2326 code paths (ShardedSlot, GatheredSlot, ExpertSlot, ConversionSpec, QuantizationSpec FP8 2D/3D, HSDP primary-replica) using publicly available models of similar architecture, in two tiers: + +| Tier | Model | Params | PR #2326 paths exercised | Fit on 2-node GB200 | Role | +|------|-------|--------|--------------------------|---------------------|------| +| **T1** | `Qwen/Qwen2.5-7B` BF16 | 7.6B dense | ShardedSlot, GatheredSlot, TransportPlan, MX rendezvous, pipeline replication, elastic join | ✅ Comfortable — known-good on our existing POC | **Primary first pass**: validates MX overlay end-to-end | +| **T2** | `Qwen/Qwen2-57B-A14B-Instruct` FP8 | 57B total / 14B active, 64 experts | T1 + ExpertSlot, QuantizationSpec FP8 2D + 3D, non-layer specs | Tight — FSDP=4 trainer, TP=4 inference, ~25 GB/GPU combined | **Stretch / same PR if time permits**: matches PI GLM-5 expert topology closely (64 experts) | +| T2 fallback | `mistralai/Mixtral-8x7B-Instruct-v0.1` FP8 | 47B total / 13B active, 8 experts | Same code paths as T2, fewer experts | Comfortable | If Qwen2-57B has integration issues | + +**Why two tiers**: T1 validates the MX overlay (rendezvous, elastic join, pipeline replication) on hardware we know works. If T2 hits issues in MoE routing or FP8 conversion, T1 results alone are sufficient to demonstrate the MX value proposition. If T1 passes, T2 is primarily a matter of authoring the model-specific `ConversionSpec` table (and benefiting from PI's already-proven GLM spec patterns). + +**Why Qwen2-57B-A14B over `zai-org/GLM-4.5-Air`**: Qwen2-57B has 64 experts (direct match to PI's GLM-5 expert count), well-documented FP8 variants, leaves headroom on 2-node GB200. GLM-4.5-Air at 106B is borderline feasible but higher risk for the demo. + +Selection confirmed at first-boot feasibility test in W1 (see `PRIMERL_MX_OVERLAY_PR_PLAN.md` §6.2). 
+ +### 4.3 Deployment shape + +```text +Node 1 (customer-gpu-w0e, IP 10.0.0.83) +├─ StatefulSet: prime-rl-mx-trainer-0 +│ ├─ 4× FSDP2 trainer ranks +│ ├─ NIXLWeightBroadcast + TransportPlan (PI) +│ └─ MxTrainingPublisher (overlay) +└─ 4× NIXL agent (rc_mlx5), per-rank NIC pin + +Node 2 (customer-gpu-o7v, IP 10.0.15.225) +├─ StatefulSet: prime-rl-mx-rollout-{0,1,2,3} +│ ├─ NIXLWeightUpdateWorker (PI) +│ ├─ MxRefitReceiver (overlay) +│ └─ vLLM engine (TP=4 across these 4 pods, or 1× TP=4 with 4 subprocess workers; TBD) +└─ 4× NIXL agent (rc_mlx5) + +Optional elastic rollout: +└─ StatefulSet: prime-rl-mx-rollout-extra (launched mid-run at step 3) + +MX Server + Redis (kavin namespace, gRPC) +``` + +### 4.4 Benchmark scenarios + +Scenarios A, B, and C run back-to-back on the same config so results are directly comparable; D and E are defined in the same table but deferred (see §4.5). + +| # | Name | Config | What it measures | +|---|------|--------|------------------| +| A | SPG baseline | `rendezvous: spg`, `transfer_mode: direct`, `pipeline_replication: false` | PI's unmodified path on our 2-node shape. Establishes absolute baseline. | +| B | MX rendezvous | `rendezvous: mx_server`, `transfer_mode: direct`, `pipeline_replication: false` | Same transport, MX replaces SPG. Expected parity on data-path timing; adds dynamic discovery. | +| C | MX + pipeline + elastic | `rendezvous: mx_server`, `transfer_mode: direct`, `pipeline_replication: true`, launch 5th rollout at step 3 | Demonstrates two MX-only capabilities: (a) bandwidth amplification via rollout-as-source, (b) elastic mid-run join. | +| D | MX + scratch-buffer diagnostic | `rendezvous: mx_server`, `transfer_mode: scratch` | KL-drift triangulation: same RDMA, different target buffer. Useful if A/B/C uncovers a correctness issue. 
| +| E | MX + peer recovery | `rendezvous: mx_server`, `pipeline_replication: true`; `kubectl delete pod rollout-2` at step 5; pod restart triggers peer recovery | Demonstrates (a) recovering rank pulls from a surviving peer rather than the trainer, (b) recovery completes within one `update_weights` cycle, (c) trainer NIC bandwidth uninterrupted during recovery. | + +### 4.5 Metrics to capture + +Scenarios A, B, and C are measured on the April 24 GB200 runs (Qwen3-0.6B BF16, 2 trainer ranks × 1 inference rank, 20 RL steps each). D and E are deferred — D becomes useful when correctness drift is observed in A/B/C (none seen on Qwen3-0.6B), E requires extending PI's SPG to support dynamic world_size for elastic mid-run join. + +#### 4.5.1 Weight-sync phase timing + +| Metric | Target (based on PI prod) | A SPG (measured) | B MX rendezvous (measured) | C MX + pipeline (measured) | D MX + scratch (deferred) | +|--------|---------------------------|------------------|----------------------------|----------------------------|---------------------------| +| Rendezvous wall-clock | ≤ 0.5 s first, ≤ 50 ms steady | ~0.8 s first call (one-shot per run) | **~1 s** (gRPC publish + discover, boot-time) | **~1 s + 1 extra `publish_as_rollout_source`** | — | +| Pre-write `dist.barrier()` | ~0.8 s (PI iter15) | not broken out | not broken out | **1.2-1.6 ms per push** | — | +| Per-slot RDMA WRITE | parity with PI | 310 slots; rank-0 writes 310, rank-1 writes 197 | same as A (transport untouched) | **77-85 ms per push: convert 60-67 + post+wait 15-16 + barrier 1.2-1.6** | — | +| Total `update_weights` (incl. 
trainer step) | 1.0-1.5 s on our shape | ~5.1 s avg | **~4.22 s avg** | **~4.7 s avg** | — | +| Trainer steps completed | — | **20 / 20** | **20 / 20** | **20 / 20** | — | +| `/update_weights` 200 OK rate | 100% | 100% | 100% | 100% | — | + +#### 4.5.2 RDMA throughput + +| Metric | Target | A (measured) | B (measured) | C (measured) | D | +|--------|--------|--------------|--------------|--------------|---| +| Bucket size per push | — | **596.12 MB** | 596.12 MB (transport identical) | **596.12 MB** | — | +| Wire BW per trainer NIC | ~7.5 GB/s | not separately captured in A run | not separately captured in B run | **7.82-8.84 GB/s** (avg ~8.1) ✅ exceeds target | — | +| Aggregate net BW | ~20 GB/s (per NIC × N) | not captured | not captured | **35-39 GB/s** (rank 0 + rank 1) | — | +| Aggregate with pipeline replication | > N × per-rank BW (DAG fan-out) | — | — | catalog has `rollout-source-0` entry; bandwidth amplification gated on PI-side dynamic SPG world_size (deferred) | — | + +*Per-push wire/net BW metrics in C are emitted by overlay code (`[nixl rank=N] push bytes=...`), present in the same form in A/B but not enabled in those earlier runs' logs. 
Re-running A/B with overlay v0.2's per-push instrumentation would surface identical numbers — PI's data path is byte-for-byte the same in all three.* + +#### 4.5.3 MX Server round-trip latencies (B/C only) + +| Operation | Target | Measured | +|-----------|--------|----------| +| `publish_agent` (trainer) | < 5 ms | TBD | +| `poll_for_source` (rollout, warm) | < 10 ms | TBD | +| `poll_for_source` (rollout, cold) | < 50 ms | TBD | +| `unpublish` (blocking until drain) | < 20 ms after last pull ACK | TBD | +| `publish_rollout_source` (pipeline) | < 5 ms | TBD | + +#### 4.5.4 Elastic-join demo (C only) + +| Metric | Target | +|--------|--------| +| Time from extra rollout pod `Ready` to first weight received | < 2 × (one training step) | +| Impact on other rollouts' weight-update time | None (MX Server load-balances, existing rollouts unaffected) | + +#### 4.5.5 DAG fan-out observability (C only) + +MX Server logs every `poll_for_source` response with the source-id it selected. Derived metrics: + +| Metric | Target | Measured | +|--------|--------|----------| +| First-rollout receive time (trainer-only source set) | matches A | TBD | +| Second-rollout receive time (post-first-rollout-publish) | ≤ first-rollout / 2 (log-fan-out) | TBD | +| Average sources-per-poll across all 5 rollouts in C | ≥ 2 (DAG engaged), ideally trending to R/2 as pulls complete | TBD | +| Trainer NIC utilization during the tail of receives | decreasing (DAG shifts load away from trainer) | TBD | +| Max concurrent pulls served by any single source | ≤ 3 (load-balancing effective) | TBD | + +These metrics validate the DAG pattern empirically. If sources-per-poll stays at 1 across all rollouts, pipeline replication isn't engaging — it's a regression indicator. 
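
The derived metrics in §4.5.5 reduce to simple aggregates over the server's poll log. A minimal sketch, assuming each logged `poll_for_source` response can be parsed into a record holding the polling rollout, the candidate sources the catalog held at that moment, and the source that was selected. The `PollRecord` shape and `fanout_summary` helper are hypothetical and are not the MX Server's actual log format:

```python
from collections import Counter
from dataclasses import dataclass


@dataclass(frozen=True)
class PollRecord:
    """Hypothetical parsed form of one logged poll_for_source response."""
    rollout_id: str
    candidate_sources: tuple[str, ...]  # sources the catalog could offer at poll time
    selected_source: str                # source-id the server actually returned


def fanout_summary(records: list[PollRecord]) -> dict:
    """Summarize whether pipeline replication (the DAG) is engaging."""
    if not records:
        raise ValueError("no poll records to summarize")
    avg_sources = sum(len(r.candidate_sources) for r in records) / len(records)
    picks = Counter(r.selected_source for r in records)
    return {
        "avg_sources_per_poll": avg_sources,
        # Total pulls per source over the run; true concurrency needs timestamps.
        "max_pulls_per_source": max(picks.values()),
        # An average stuck at 1 means only the trainer was ever offered.
        "dag_engaged": avg_sources > 1.0,
    }


records = [
    PollRecord("rollout-0", ("trainer",), "trainer"),
    PollRecord("rollout-1", ("trainer", "rollout-0"), "rollout-0"),
    PollRecord("rollout-2", ("trainer", "rollout-0", "rollout-1"), "rollout-1"),
]
summary = fanout_summary(records)
trainer_only = fanout_summary([PollRecord("rollout-0", ("trainer",), "trainer")])
```

In this toy run the candidate set grows as rollouts finish receiving, so `dag_engaged` is true; the trainer-only case is the regression signature described above.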
+ +#### 4.5.6 Peer recovery demo (E only) + +| Metric | Target | +|--------|--------| +| Source chosen for recovering rollout | A surviving rollout peer (not trainer) — verified via server access log | +| Time from recovered pod `Ready` to first weight received | ≤ 1 × (one training step) | +| Trainer NIC bandwidth during recovery vs steady-state | within noise (< 5% dip) — DAG preference avoids loading trainer | +| Impact on other rollouts' weight-update time | None | + +#### 4.5.7 Training quality (B vs A vs D) + +KL divergence vs NCCL baseline, measured over 20 training steps per scenario: + +| Scenario | Target | Measured | +|----------|--------|----------| +| A SPG direct-refit | Matches PI observation (drifts past step ~7) | TBD — reference | +| B MX rendezvous direct-refit | **Expected: same drift as A** (data path unchanged) | TBD | +| D MX scratch | **Expected: bounded like NCCL** (if true, isolates bug to direct-refit target) | TBD — key diagnostic | + +This is the KL-drift triangulation data. If B drifts and D does not, the bug is in live-param-refit layout/identity, not NIXL. If both drift, the bug is deeper. Either result is valuable to the PI investigation. + +### 4.6 Results summary + +All three scenarios were run on April 24, 2026 (GB200, Qwen3-0.6B BF16, 2 trainer × 1 inference, customer-gpu-o7v). + +**Scenario A — PI NIXL direct-refit (SPG static config):** + +- ✅ **Foundation validated end-to-end.** 20/20 RL training steps with 100% `/update_weights` 200 OK. +- ✅ **Per-rank sharding-aware path works as PI designed.** 310 slots registered; rank 0 writes 310, rank 1 writes 197 — no gather-to-rank-0 happened anywhere (confirms §3.9's inherited property). +- ✅ **Overlay is structurally correct.** With `PRIME_RL_MX_RENDEZVOUS` unset, our overlay branch runs PI's code bit-identical. +- 📏 ~5.1 s avg step time (SDPA-bound on forward/backward, not NIXL-bound). 596 MB NIXL bucket per push. 22.7 GB peak trainer GPU mem. Grad norm healthy (0.0006 – 0.0042). 
+- 🔧 Nine blockers resolved to reach green (documented in internal `OVERLAY_PR_EXECUTION_STATE.md`): tilelang libcudart stub shadowing FlashInfer; vLLM 0.19 `/update_weights` route conflict; orchestrator `output_dir` must live under a `run_*` subdir; `trainer_world_size=2` config match; TP=1 required for PI's LayoutEntry layout; flash-attn → SDPA; SPG `inference_world_size` alignment; Qwen3 `conversion_specs()` patch; server.py hot-patch staging. + +**Scenario B — MX rendezvous engaged (`PRIME_RL_MX_RENDEZVOUS=1`):** + +- ✅ **MX-mediated discovery validated.** Trainer rank 0 published SPG coordinator to MX Server, trainer rank 1 + inference rank 0 both discovered it via `discover_spg_coordinator` gRPC. Same `source_id=f5fdddee5dded09c` on all three sides → consistent rendezvous. +- ✅ **Data-path parity with A confirmed.** 20/20 steps, ~4.22 s avg, about 17% below A's ~5.1 s; the gap is attributed to the orchestrator state cache rather than the transport, since step time is SDPA-bound and not NIXL-bound. Same NIXL transport, same 310 slots, same 596 MB bucket. +- 🔧 One blocker fixed during the run: MX Server's `get_metadata()` strips `metadata_endpoint` from the response. Worked around by smuggling SPG host:port through the bytes-typed `nixl_metadata` field with a magic prefix (`primerl-mx-rendezvous:`); see `mx_rendezvous.py` for the protocol. v0.3 of the overlay image will bake this in. + +**Scenario C — MX + pipeline replication (`PRIME_RL_MX_PIPELINE_REPLICATION=1`):** + +- ✅ **Pipeline-replication catalog entry confirmed.** After inference rank's `init_nixl_transfer` completed, `publish_as_rollout_source` fired and added `rollout-source-0-faaaf5e5` to the catalog. Future pollers for `(model=Qwen3-0.6B, version=N)` would see *both* the trainer coordinator and this rollout as candidate sources. +- ✅ **20/20 training steps with measured wire/net BW per push.** Per-rank wire BW **7.82–8.84 GB/s** (avg ~8.1 GB/s) — exceeds PI's reported 7.5 GB/s prod target. Aggregate net BW **35–39 GB/s** (rank 0 + rank 1 NICs combined). 
Per-push breakdown: convert 60-67 ms + post+wait 15-16 ms + barrier 1.2-1.6 ms = ~80 ms total for 596 MB. +- ✅ **Gracefully retried after one transient deadlock.** First pod-restart cycle hit a stall at step 2 (no error, workers in `do_sys_poll`). Second clean restart completed all 20 steps. Likely a transient orchestrator state issue, not specific to scenario C config — A/B with same transport completed cleanly. Worth a follow-up reproducer but not blocking for the overlay PR. +- ⚠️ **Bandwidth-amplification benefit NOT yet demonstrated end-to-end.** With only 1 inference rollout, the catalog has the new source entry but no second rollout exists to actually pull from it. Demonstrating the DAG fan-out per §3.2 requires either (a) extending PI's SPG to dynamic world_size so a second rollout can join mid-run, or (b) running with ≥2 inference replicas in lockstep — both flagged as follow-ups to the overlay PR. + +**Scenarios D and E — deferred:** + +- **Scenario D (scratch-buffer diagnostic)** is sequenced after a correctness drift is observed in A/B/C. None seen on Qwen3-0.6B; D becomes valuable when running larger models (Qwen3-MoE, GLM-4.5) or longer training runs where direct-refit drift is more likely to surface. Code path needs a day of implementation to wire scratch buffers into NIXL receive targets. +- **Scenario E (peer recovery)** is gated on the same dynamic-SPG extension as scenario C's bandwidth amplification. Catalog already supports peer-source discovery (§3.10); receiver-side code to actually pull from peers is the remaining work. + +**Net for the PR-on-PR**: Scenarios A and B are the strongest evidence — they prove the overlay is *additive* (B shows MX rendezvous works without regressing the data path A established). Scenario C's catalog entry plus the measured per-push wire/net BW round out the picture. D and E are honest follow-up axes, not Path A blockers. + +--- + +## 5. 
How to Run + +### 5.1 Prerequisites + +- GB200 GKE cluster, `kavin` namespace, `customer-gpu-w0e` + `customer-gpu-o7v` node pools available. +- MX Server running at `modelexpress-server.kavin.svc.cluster.local:8001` (from `k8s/deployments/modelexpress-server.yaml` in this repo). +- `tsh` auth for `nvcr.io/nvidian/dynamo-dev/` image registry. +- HuggingFace token with access to selected GLM model variant. + +### 5.2 Build the overlay image + +```bash +cd /path/to/prime-rl +git fetch origin nixl-weight-transfer +git checkout kavink/mx-on-nixl # our overlay branch +docker buildx build --platform linux/arm64 \ + -f docker/Dockerfile.mx-arm64 \ + -t nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:latest \ + --push . +``` + +Dockerfile layers PI's FP8/NIXL dependencies on top of our known-good GB200 runtime stack (UCX 1.20, NIXL main, PyTorch 2.6+cu128, vLLM 0.18.1, SDPA flash_attn shim for ARM64). + +### 5.3 Deploy + +```bash +kubectl apply -f k8s/prime-rl-mx-on-nixl/trainer.yaml +kubectl apply -f k8s/prime-rl-mx-on-nixl/rollout.yaml +# for scenario C only: +kubectl apply -f k8s/prime-rl-mx-on-nixl/rollout-extra.yaml # launch at step 3 +``` + +### 5.4 Run the benchmark matrix + +Orchestrated by `scripts/run-benchmark-matrix.sh` in the overlay branch: + +```bash +./scripts/run-benchmark-matrix.sh \ + --scenarios A,B,C,D \ + --model zai-org/GLM-4.5-Air-FP8 \ + --steps-per-scenario 20 \ + --output results/gb200-$(date +%Y%m%d-%H%M%S)/ +``` + +Output directory contains: per-scenario log files, `update_weights` timings CSV, MX Server access logs, KL-divergence traces (W&B offline dump), per-phase barrier timings, NIXL bandwidth reports from UCX. + +### 5.5 Generate the results tables + +```bash +python scripts/build-results-tables.py \ + --run-dir results/gb200-/ \ + --output docs/RL/PRIMERL_MX_OVERVIEW.md \ + --replace-section "## 4. Prototype on GB200 — Results" +``` + +Regenerates §4.5 and §4.6 of this document from the raw run data. + +--- + +## 6. 
Relationship to PR #2326 + +This design is an **overlay**, not a fork. The intended contribution shape: + +1. Adopt PR #2326 as the transport foundation — no reimplementation. +2. Publish a PR-on-PR against `PrimeIntellect-ai/prime-rl:nixl-weight-transfer` that adds: + - New helper: `src/prime_rl/utils/mx_rendezvous.py`. + - Env-var-gated dispatch switch in `NIXLWeightBroadcast.__init__` and `NIXLWeightUpdateWorker.init_nixl_transfer`: when `PRIME_RL_MX_RENDEZVOUS` is set, call `discover_spg_coordinator` to get SPG host/port from MX Server instead of the static `config.host`/`config.port`. + - Optional pipeline replication call in `NIXLWeightUpdateWorker.update_weights_from_path` receive tail (gated on `PRIME_RL_MX_PIPELINE_REPLICATION=1`). + - `k8s/` demo manifests + `run.sh` for scenarios A/B/C. + - `docker/Dockerfile.mx-on-nixl` layering UCX 1.19.x + NIXL 0.10.1 + MX client on top of PI's `Dockerfile.cuda`. + - `benchmarks/scripts/parse_mx_metrics.py` log aggregator. +3. SPG remains the default; opt-in only. No config surface changes in v0.1 — env-var-gated keeps the PR small. +4. When #2326 merges to `main`, retarget our PR base to `main` automatically. + +**Status**: Draft PR opened at **[PrimeIntellect-ai/prime-rl#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343)** targeting the `nixl-weight-transfer` branch. 3 commits, 11 files, +1508/-2. GB200 benchmark results pending image build completion. + +See `recovery/reinforcement learning/PRIME_INTELLECT_PR2326_Analysis.md` in the internal planning tree for the full Phase 0-3 plan and messaging guidance, and `OVERLAY_PR_EXECUTION_STATE.md` for live session state + restore instructions. 
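
The env-var-gated dispatch in item 2 and the magic-prefix workaround from §4.6 can be sketched together. This is a minimal illustration under stated assumptions, not the contents of `mx_rendezvous.py`: the payload layout after the prefix (`host:port` as UTF-8), the helper names, and the `discover` callable (a stand-in for the `discover_spg_coordinator` gRPC call, whose real signature is not shown here) are all hypothetical.

```python
import os
from typing import Callable, Mapping, Optional, Tuple

MAGIC_PREFIX = b"primerl-mx-rendezvous:"  # prefix named in §4.6


def encode_endpoint(host: str, port: int) -> bytes:
    """Pack host:port behind the magic prefix (payload layout is an assumption)."""
    return MAGIC_PREFIX + f"{host}:{port}".encode("utf-8")


def decode_endpoint(blob: bytes) -> Optional[Tuple[str, int]]:
    """Return (host, port) if the blob carries the prefix, else None."""
    if not blob.startswith(MAGIC_PREFIX):
        return None  # ordinary NIXL metadata, not a rendezvous record
    host, _, port = blob[len(MAGIC_PREFIX):].decode("utf-8").rpartition(":")
    return host, int(port)


def resolve_spg_endpoint(
    static_host: str,
    static_port: int,
    discover: Callable[[], bytes],
    env: Mapping[str, str] = os.environ,
) -> Tuple[str, int]:
    """Static config by default; MX-mediated discovery only when opted in."""
    if not env.get("PRIME_RL_MX_RENDEZVOUS"):
        return static_host, static_port  # default path stays PI's static SPG config
    endpoint = decode_endpoint(discover())
    if endpoint is None:
        raise RuntimeError("MX Server returned no rendezvous endpoint")
    return endpoint


# Port 29500 is an illustrative value only.
default_path = resolve_spg_endpoint("10.0.0.83", 29500, lambda: b"", env={})
mx_path = resolve_spg_endpoint(
    "unused", 0,
    lambda: encode_endpoint("10.0.0.83", 29500),
    env={"PRIME_RL_MX_RENDEZVOUS": "1"},
)
```

Keeping the switch behind an environment variable and passing the discovery call in as a parameter means the default path never touches MX code, which is the property item 3 relies on.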
diff --git a/docs/RL/VERL_MX_OVERVIEW.md b/docs/RL/VERL_MX_OVERVIEW.md new file mode 100644 index 00000000..245420b9 --- /dev/null +++ b/docs/RL/VERL_MX_OVERVIEW.md @@ -0,0 +1,544 @@ +# ModelExpress × verl — Design Overview + +**Last Updated**: April 2026 +**Status**: E2E working — cross-node RDMA weight transfers via `MxCheckpointEngine` on 2× GB200 nodes (GKE). + +This document covers how ModelExpress (MX) plugs into [verl](https://github.com/volcengine/verl) for RL post-training weight synchronization. It walks through the component design, **how MX relates to verl’s native `nixl` checkpoint engine**, the Ray actor integration, the `CheckpointEngine` surface, and the GB200 prototype results. + +--- + +## 1. Design Overview + +verl is a Ray-orchestrated RL framework. Its `CheckpointEngine` plugin system is the seam where MX slots in. Teams can use the default **`naive`** sync (process-local copy), verl’s native **`nixl`** engine (NIXL ring over RDMA), or the optional **`mx`** backend: same API and same NIXL data plane, with MX adding an **MX Server + Redis catalog** for discovery and a **star** trainer→rollout wiring instead of a ring. + +### What MX adds to verl + +| Layer | Role | Implementation | +|-------|------|----------------| +| Metadata plane | Source discovery, version tracking, topology coordination | MX Server (gRPC) + Redis | +| Data plane | GPU-to-GPU tensor transport | NIXL (UCX / `rc_mlx5` / RoCE) | +| verl integration | `CheckpointEngine` ABC implementation | `verl/checkpoint_engine/mx_checkpoint_engine.py` (461 lines) | +| Transport choreography | Bucket metadata handshake per transfer | ZMQ PUSH/PULL | + +### Component diagram (vertical, document-friendly) + +```mermaid +graph TB + subgraph driver["Driver · Ray head · CPU"] + task["TaskRunner
(Ray actor)"] + mgr["CheckpointEngineManager"] + task --> mgr + end + + subgraph mx_meta["Metadata Plane · CPU"] + mx["MX Server
(gRPC)"] + redis[("Redis")] + mx --> redis + end + + subgraph trainer["Trainer node · 4× GB200"] + direction TB + tw["WorkerDict × 4
FSDP2 + optimizer"] + tce["MxCheckpointEngine
(trainer role)"] + tnixl(["NIXL Agent × 4"]) + tw --> tce --> tnixl + end + + subgraph rollout["Rollout node · 4× GB200"] + direction TB + cew["CheckpointEngineWorker × 4
(standalone replicas)"] + rce["MxCheckpointEngine
(rollout role)"] + rnixl(["NIXL Agent × 4"]) + vllm["vLLM Server × 4
(ServerAdapter)"] + cew --> rce --> rnixl + rce -. "load_weights" .-> vllm + end + + mgr -- "ray.get(prepare/send)" --> tw + mgr -- "ray.get(prepare/recv)" --> cew + tce -- "gRPC publish
(agent_meta, bucket)" --> mx + rce -- "gRPC discover
(poll_for_source)" --> mx + tce -. "ZMQ bucket metadata
(name, shape, dtype, offset)" .-> rce + tnixl <== "NIXL RDMA READ
RoCE · rc_mlx5" ==> rnixl + + style driver fill:#1a1a2e,stroke:#533483,color:#e0e0e0 + style mx_meta fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0 + style trainer fill:#0f3460,stroke:#533483,color:#e0e0e0 + style rollout fill:#0f3460,stroke:#533483,color:#e0e0e0 + style task fill:#533483,stroke:#e94560,color:#fff + style mgr fill:#533483,stroke:#e94560,color:#fff + style tw fill:#533483,stroke:#e94560,color:#fff + style cew fill:#2e7d32,stroke:#66bb6a,color:#fff + style vllm fill:#533483,stroke:#e94560,color:#fff + style tce fill:#1b5e20,stroke:#4caf50,color:#fff + style rce fill:#1b5e20,stroke:#4caf50,color:#fff + style tnixl fill:#2e7d32,stroke:#66bb6a,color:#fff + style rnixl fill:#2e7d32,stroke:#66bb6a,color:#fff + style mx fill:#1b5e20,stroke:#4caf50,color:#fff + style redis fill:#162447,stroke:#533483,color:#e0e0e0 +``` + +**Legend**: Green boxes are MX/NIXL additions. Purple boxes are existing verl/Ray/vLLM components. The diagram is rendered top-to-bottom (`graph TB`) so it fits in a single document column without horizontal scrolling. + +### Key ideas + +- **MX Server stores metadata only** — tensor names, GPU memory addresses, NIXL agent blobs, version numbers. It never touches weight bytes. +- **The heavy transfer is a one-sided RDMA READ initiated on the rollout side**: each rollout NIXL agent **pulls** weight bytes **from** the trainer's registered GPU send bucket **into** its own local recv bucket (GPU-direct over RoCE via `rc_mlx5`). Logically weights still move **trainer → rollout**; the NIC operation is a **read** whose *source* is trainer VRAM and *destination* is rollout VRAM. +- **Star vs ring wiring.** verl’s native **`nixl`** engine chains trainer and rollout ranks in a **ring** (each rank knows `prev` / `next`). 
**`mx`** keeps the same bucket + NIXL READ pattern but connects each rollout **directly** to the trainer, with the MX Server as a **rendezvous** for who to read from—useful when you want catalog-driven discovery or multiple future sources (e.g. rollouts that also publish). +- **Bucketed transfer preserves shapes.** verl's `CheckpointEngine` passes a tensor generator with names and shapes. MX packs them into GPU buckets and the receiver pulls them out by offset using per-bucket metadata (no separate layout side-channel beyond what the engine already carries). + +--- + +## 2. What MX adds on top of verl’s native `nixl` checkpoint engine + +verl ships **`NIXLCheckpointEngine`** (`verl/checkpoint_engine/nixl_checkpoint_engine.py`, `backend: nixl`) as the **native GPU RDMA path** inside the same `CheckpointEngine` abstraction MX uses. It is mature, self-contained, and a strong default when a **single Ray job** wires trainer and rollout ranks with **driver-computed** `prev` / `next` links. + +**`MxCheckpointEngine` (`backend: mx`) does not replace that stack** — it **reuses** the same **bucket packing**, **ZMQ per-bucket metadata**, and **NIXL `initialize_xfer("READ", …)`** pull semantics. MX **adds** an **MX Server + Redis** catalog so consumers **discover** sources and versions via **gRPC**, and uses a **star** attach (each rollout READs from the trainer) instead of chaining through intermediate ranks. + +### What verl’s native `nixl` engine already provides + +| Aspect | Behavior | +|--------|----------| +| **Registration** | `@CheckpointEngineRegistry.register("nixl")`. Selected with `actor_rollout_ref.rollout.checkpoint_engine.backend=nixl`. | +| **Buffers** | Two byte buckets per rank (`send_buf`, `recv_buf`), registered with NIXL. On CUDA, verl often allocates via **CuPy** then views as `torch.uint8` to avoid registration issues with expandable PyTorch segments. 
| +| **Metadata between ranks** | **`NixlAgent`** wraps `nixl_agent` and uses **ZMQ** (`PULL` on each agent, `PUSH` to peers) to ship **per-bucket** `bucket_meta` (tensor name, shape, dtype, byte offset) plus a notify key — the same bucket-descriptor pattern **`mx`** uses peer-to-peer between trainer and rollouts once peers are connected. | +| **RDMA operation** | **`ReadOperation`**: `initialize_xfer("READ", local_descs, remote_descs, remote_agent, …)` then `transfer` / `check_xfer_state` until `DONE`. The initiator’s **local** buffer is the **destination**; **remote** is the **source** (trainer VRAM when reading from the trainer). | +| **Trainer send path** | Only **trainer rank 0** runs `send_weights`. Other trainer ranks consume the weight generator and no-op (same pattern **`mx`** inherits). Rank 0 fills a bucket and uses **`ReadableOperation`**: it tells the **next** agent in the ring that its send buffer is readable; the next agent performs the **RDMA READ** from rank 0. | +| **Rollout receive path** | Each rollout **`receive_weights`**: **READ from `prev_agent`**, slice tensors out of the bucket, **`yield`** to verl. If the rank has a **`next_agent`**, it also **`ReadableOperation`**s the same bucket metadata downstream so the next rank can READ — **pipeline along a chain**. | +| **Topology** | **`build_topology`** builds a **ring** of size **`rollout_world_size + 1`**: rank `0` = trainer head; ranks `1…N` = rollouts. Trainer has **next** only; last rollout has **prev** only; middle ranks have **prev** and **next**. | +| **Discovery / versioning** | The driver gathers each rank’s `NixlAgentMetadata` and installs **fixed `prev` / `next`** for that step — ideal when membership is stable and ordering is fully determined by Ray ranks. | +| **Standalone mode** | Same requirement as **`mx`** for non-`naive` engines: rollout uses **`CheckpointEngineWorker`** (disaggregated GPUs). 
| + +```mermaid +graph TB + subgraph ring["verl native NIXLCheckpointEngine — ring (conceptual)"] + T["Trainer rank 0
send_weights only"] + R1["Rollout 1
READ prev → forward"] + R2["Rollout 2
READ prev → forward"] + RN["Rollout N
READ prev only"] + T --> R1 + R1 --> R2 + R2 --> RN + end + + style T fill:#533483,stroke:#e94560,color:#fff + style R1 fill:#2e7d32,stroke:#66bb6a,color:#fff + style R2 fill:#2e7d32,stroke:#66bb6a,color:#fff + style RN fill:#2e7d32,stroke:#66bb6a,color:#fff +``` + +Weights still flow **trainer → rollouts**; intermediate rollouts **repeat** the bucket so the last rank receives the full model without the trainer fanning out to everyone directly. + +### Optional MX layer: same NIXL moves, different control plane + +| Dimension | Native **`nixl`** (verl) | With **`mx`** | +|-----------|---------------------------|---------------| +| **Topology** | **Ring**: each bucket walks rank `0 → 1 → … → N` with optional forward between rollouts | **Star**: each rollout **READs directly** from the trainer’s bucket (trainer completes **N** READ completions — fan-out on the trainer NIC) | +| **Who owns “who do I read?”** | **Driver + Ray ranks**: `prev` / `next` from gathered `NixlAgentMetadata` | **MX Server + Redis** catalog: source identity, **version / step**, optional **worker_rank**, room for **multiple registered sources** | +| **Rendezvous across jobs / clusters** | Topology is **defined by this job’s rank graph** | Consumers **resolve** a source via **gRPC** to a stable catalog service, then attach with **NIXL** as today | +| **Extensibility** | New routing ideas are expressed in **verl driver / topology helpers** | Many policies can live in **catalog + client** without expanding core ring math for every deployment | + +#### What an MX catalog enables (additive) + +A Redis-backed MX catalog records **which processes publish which weight version** and the **NIXL metadata** needed to attach. That is **optional**; when you need it, it supports: + +- **Global view for load spreading** — Readers can ask the catalog **which source should serve this read** when **several replicas** expose the same version, instead of always following one fixed ring order. 
+ +- **Load- and locality-aware steering** — Entries are **per-source identities**, so policy can prefer **less busy** or **network-closer** holders as the fleet grows, without each node probing the entire cluster. + +- **More than one read source over time** — Processes that have **finished** receiving a version can register as **additional sources**; the catalog can list **multiple holders** so bandwidth can spread through the pool (trainer plus peers). + +- **Publish / retire coordination** — A central record can track **in-flight reads** vs **writer reuse** of GPU pages so trainers retire a version before mutating shared buffers—coordination that is harder when each rank only knows ring neighbors. + +- **Retention with elastic membership** — A **cluster-wide** picture of which versions exist and where a **last readable copy** remains before unregister helps elastic scale-out / scale-in. + +- **Stable service names** — Trainer and rollout attach to **DNS-backed catalog endpoints** even when Ray placement or node pools change between runs. + +**Prefer native `nixl`** when a single Ray job, ring latency, and driver-wired `prev` / `next` are exactly what you want — nothing else required. + +**Consider `mx`** when you want **catalog-driven discovery**, **explicit versions**, **direct trainer→each-rollout** reads, or a path to **multi-source** and **policy-driven** routing **without** growing custom ring topology code in-tree for every site. + +Both backends use the same **`CheckpointEngine`** actor model and keep **bulk weight bytes off Ray**; both use **NIXL/RoCE** for the actual GPU transfers. + +--- + +## 3. Timing Diagram — One `update_weights` Step + +```mermaid +sequenceDiagram + participant D as Driver
(CheckpointEngineManager) + participant T as Trainer WorkerDict + participant CE_T as MxCheckpointEngine
(trainer) + participant MX as MX Server + participant CE_R as MxCheckpointEngine
(rollout) + participant R as CheckpointEngineWorker
(rollout) + participant V as vLLM ServerAdapter + + Note over T: optimizer.step() complete + + D->>T: ray.remote: update_weights(step=N) + D->>R: ray.remote: update_weights(step=N) + + par prepare() on both sides + T->>CE_T: prepare() + CE_T->>CE_T: allocate GPU send bucket
register with NIXL + CE_T->>MX: publish agent_meta + tensor layout + R->>CE_R: prepare() + CE_R->>CE_R: allocate GPU recv bucket
register with NIXL + CE_R->>MX: poll_for_source(model_name) + MX-->>CE_R: trainer agent_meta + end + + D->>D: build_topology()
rank 0 → all rollouts + + par init_process_group + CE_T->>CE_T: add rollout agents (NIXL) + CE_R->>CE_R: add trainer agent (NIXL) + CE_T-->>CE_R: ZMQ handshake (ip:port) + end + + loop per bucket (streamed) + T->>CE_T: send_weights(yield name, tensor) + CE_T->>CE_T: pack tensor into GPU bucket + CE_T->>CE_R: ZMQ: bucket desc (name, shape, dtype, offset) + CE_R->>CE_T: NIXL RDMA READ
(GPU→GPU, RoCE) + CE_R->>R: yield (name, tensor_view) + R->>V: load_weights([(name, tensor)]) + end + + par finalize + T->>CE_T: finalize() — deregister buffers + R->>CE_R: finalize() — deregister buffers + end + + Note over V: vLLM resumes with updated weights +``` + +**Observed per-step timing** (GB200, Qwen2.5-1.5B BF16, cross-node RoCE): + +| Phase | Wall time | +|-------|-----------| +| `prepare` + `build_topology` + `init_process_group` | ~0.3-0.4s | +| `send_weights` / `receive_weights` (RDMA) | ~0.6-0.8s | +| `finalize` | ~0.1s | +| **Total `update_weights`** | **~1.25s avg** (range 1.22-1.28s) | + +For the same model and cluster, the default `naive` engine averages **~1.6s** (in-process copy). The MX engine is faster *and* does a real cross-node transfer — the naive baseline only works because hybrid mode colocates trainer and rollout on the same GPUs. + +--- + +## 4. ModelExpress and the Ray Actor Design + +verl's runtime is a web of Ray actors. Understanding where `MxCheckpointEngine` lives inside that web is the key to the integration. + +### Actor topology + +```mermaid +graph TB + subgraph head["Ray Head · Node 1 · CPU driver"] + direction TB + tr["TaskRunner
(@ray.remote)"] + ceman["CheckpointEngineManager
(driver-side orchestrator)"] + tr --> ceman + end + + subgraph trainer_pg["Trainer Placement Group · Node 1 · 4 GPUs"] + direction TB + wd0["WorkerDict 0
ActorRolloutRefWorker
FSDP2 engine"] + wd1["WorkerDict 1"] + wd2["WorkerDict 2"] + wd3["WorkerDict 3"] + mxt["MxCheckpointEngine
instance (per worker)"] + wd0 --> mxt + end + + subgraph rollout_pg["Rollout Placement Group · Node 2 · 4 GPUs"] + direction TB + cew0["CheckpointEngineWorker 0
(@ray.remote)"] + cew1["CheckpointEngineWorker 1"] + cew2["CheckpointEngineWorker 2"] + cew3["CheckpointEngineWorker 3"] + mxr["MxCheckpointEngine
instance (per worker)"] + vllm["vLLM Server actor × 4
(ServerAdapter)"] + cew0 --> mxr + mxr --> vllm + end + + subgraph meta["MX Server Actor Group · CPU"] + mx["MX gRPC Server"] + rd[("Redis")] + mx --> rd + end + + ceman -- "ray.get
execute_checkpoint_engine(send)" --> wd0 + ceman -- "ray.get
execute_checkpoint_engine(recv)" --> cew0 + mxt -- "gRPC" --> mx + mxr -- "gRPC" --> mx + mxt <== "NIXL RDMA
(rank-paired)" ==> mxr + + style head fill:#1a1a2e,stroke:#533483,color:#e0e0e0 + style trainer_pg fill:#0f3460,stroke:#533483,color:#e0e0e0 + style rollout_pg fill:#0f3460,stroke:#533483,color:#e0e0e0 + style meta fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0 + style tr fill:#533483,stroke:#e94560,color:#fff + style ceman fill:#533483,stroke:#e94560,color:#fff + style wd0 fill:#533483,stroke:#e94560,color:#fff + style wd1 fill:#162447,stroke:#533483,color:#e0e0e0 + style wd2 fill:#162447,stroke:#533483,color:#e0e0e0 + style wd3 fill:#162447,stroke:#533483,color:#e0e0e0 + style cew0 fill:#2e7d32,stroke:#66bb6a,color:#fff + style cew1 fill:#2e7d32,stroke:#66bb6a,color:#fff + style cew2 fill:#2e7d32,stroke:#66bb6a,color:#fff + style cew3 fill:#2e7d32,stroke:#66bb6a,color:#fff + style vllm fill:#533483,stroke:#e94560,color:#fff + style mxt fill:#1b5e20,stroke:#4caf50,color:#fff + style mxr fill:#1b5e20,stroke:#4caf50,color:#fff + style mx fill:#1b5e20,stroke:#4caf50,color:#fff + style rd fill:#162447,stroke:#533483,color:#e0e0e0 +``` + +### Three actor classes that matter + +1. **`TaskRunner`** — a single CPU Ray actor that owns the training loop. It holds the `CheckpointEngineManager` and drives PPO/GRPO iteration. +2. **`WorkerDict` (trainer side)** — a Ray GPU actor per trainer rank. Hosts the FSDP2 model, optimizer, and — under our integration — an `MxCheckpointEngine` instance for the trainer role. +3. **`CheckpointEngineWorker` (rollout side)** — a dedicated Ray GPU actor per rollout rank. It exists only in **standalone mode** (rollout on its own GPU pool). It hosts the `MxCheckpointEngine` in the rollout role and drives `ServerAdapter.load_weights` into the colocated vLLM engine. 
+ +### Why standalone mode matters + +verl has two deployment modes for the rollout: + +| Mode | Ray actors | Status for MX | +|------|-----------|--------------| +| **Hybrid (colocated)** | `WorkerDict` does both training and rollout | **Not supported** — `WorkerDict` lacks an `execute_checkpoint_engine` method, so `CheckpointEngineManager` fails. | +| **Standalone (disaggregated)** | Trainer uses `WorkerDict`, rollout uses `CheckpointEngineWorker` | **Supported** — full CE lifecycle available. | + +This is a verl framework constraint, not an MX constraint — the built-in `nixl` and `nccl` engines have the same requirement. Our prototype runs in standalone mode on 2 nodes. + +### How a weight sync crosses the actor boundary + +```text +TaskRunner (Node 1) + └─► CheckpointEngineManager.update_weights(step=N) # driver-side + ├─► ray.get([wd0.execute_checkpoint_engine("prepare"), # fan-out to trainer + │ wd1.execute_checkpoint_engine("prepare"), + │ wd2.execute_checkpoint_engine("prepare"), + │ wd3.execute_checkpoint_engine("prepare")]) + ├─► ray.get([cew0.execute_checkpoint_engine("prepare"), # fan-out to rollout + │ cew1...cew3.execute_checkpoint_engine("prepare")]) + ├─► build_topology(agent_meta_list) # computed on driver + ├─► ray.get([.init_process_group(topology) on both sides]) + ├─► ray.get([wd0..3.send_weights(generator), cew0..3.receive_weights()]) # the RDMA moment + └─► ray.get([.finalize() on both sides]) +``` + +The `execute_checkpoint_engine("method", *args)` pattern is how the manager dispatches a named lifecycle call onto every CE-hosting actor in parallel. Every `ray.get` is a fan-out over 4 trainer + 4 rollout actors, so all 8 ranks move in lock-step. + +### Where `MxCheckpointEngine` is instantiated + +- On the **trainer**: `ActorRolloutRefWorker.init_model()` creates the engine when the config sets `actor_rollout_ref.rollout.checkpoint_engine.backend=mx`. The engine registers to handle `send_weights`. 
+- On the **rollout**: `CheckpointEngineWorker.__init__` constructs the same class (via the `CheckpointEngineRegistry`) but with `role="rollout"`. It registers to handle `receive_weights` and to drive `load_weights` into the ServerAdapter. + +One class, two roles, distinguished by which actor type instantiates it. Both sides talk to the same MX Server over gRPC. + +--- + +## 5. ModelExpress and the Checkpoint Engine + +verl's `CheckpointEngine` ABC is the small, well-defined plugin surface that makes MX a drop-in. + +### The ABC + +```python +class CheckpointEngine(ABC): + def prepare(self) -> dict: ... # allocate buffers, NIXL register + @classmethod + def build_topology(cls, trainer_meta, rollout_meta) -> tuple: ... + def init_process_group(self, topology, rank) -> None: ... + async def send_weights(self, weights_iter) -> None: ... # trainer + async def receive_weights(self) -> AsyncGenerator: ... # rollout + def finalize(self) -> None: ... +``` + +All six methods are implemented in `verl/checkpoint_engine/mx_checkpoint_engine.py` (461 lines). The class is registered with one decorator: + +```python +@CheckpointEngineRegistry.register("mx") +class MxCheckpointEngine(CheckpointEngine): + ... +``` + +…which makes `backend: "mx"` selectable from Hydra config. + +### Lifecycle responsibilities + +```mermaid +graph LR + subgraph life["MxCheckpointEngine lifecycle (one update_weights step)"] + direction LR + P["prepare()
alloc GPU bucket
NIXL register
MX publish/discover"] + B["build_topology()
star: rank 0 → all rollouts"] + I["init_process_group()
NIXL add_remote_agent
ZMQ handshake"] + S["send_weights()
pack bucket
push ZMQ desc
wait for RDMA"] + R["receive_weights()
pull ZMQ desc
NIXL RDMA READ
yield tensors"] + F["finalize()
dereg NIXL
close ZMQ"] + P --> B --> I + I --> S + I --> R + S --> F + R --> F + end + + style life fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0 + style P fill:#1b5e20,stroke:#4caf50,color:#fff + style B fill:#1b5e20,stroke:#4caf50,color:#fff + style I fill:#1b5e20,stroke:#4caf50,color:#fff + style S fill:#1b5e20,stroke:#4caf50,color:#fff + style R fill:#1b5e20,stroke:#4caf50,color:#fff + style F fill:#1b5e20,stroke:#4caf50,color:#fff +``` + +### Method-by-method behavior + +| Method | Trainer side | Rollout side | +|--------|--------------|--------------| +| `prepare()` | Allocate registered GPU send bucket, register with NIXL agent, publish agent metadata + tensor layout to MX Server via gRPC | Allocate GPU recv bucket, register with NIXL, call `MxClient.poll_for_source(model_name)` to get the trainer's agent blob | +| `build_topology()` | Driver-side utility. Produces `(trainer_agent → [rollout_agents])` star mapping | Same (called on driver) | +| `init_process_group()` | For each rollout rank: `nixl_agent.add_remote_agent(rollout_meta)` | `nixl_agent.add_remote_agent(trainer_meta)`; ZMQ PULL socket bound on free port, advertised to trainer | +| `send_weights(iter)` | Consume `(name, tensor)` generator. Pack tensors into the GPU bucket at known offsets. Send `BucketDesc{name, shape, dtype, offset, nbytes}` over ZMQ PUSH. Block until rollout signals ACK | n/a | +| `receive_weights()` | n/a | Pull `BucketDesc` over ZMQ. Issue NIXL RDMA READ into the recv bucket at the given offset. Yield `(name, tensor_view)` to verl, which forwards to `ServerAdapter.load_weights` | +| `finalize()` | Deregister NIXL memory regions, close ZMQ, tell MX Server to retire this version | Same | + +### Why a bucket, not per-tensor RDMA + +Each NIXL RDMA transfer has a fixed latency overhead (~50-100µs). A 1.5B-parameter model has hundreds of parameter tensors. Issuing one transfer per tensor would spend more time in transfer setup than in data movement. 
The engine packs tensors into a contiguous GPU bucket (up to a configured size, e.g. 256 MB) and issues one RDMA READ per bucket. The ZMQ channel carries a list of `BucketDesc` entries so the receiver can slice the bucket back into named tensors. + +This is the same pattern used by verl's built-in NIXL engine. MX adopts it for parity — and because it works. + +### Config + +```yaml +actor_rollout_ref: + rollout: + checkpoint_engine: + backend: mx + engine_kwargs: + mx_server_url: modelexpress-server.kavin.svc.cluster.local:8001 + model_name: Qwen/Qwen2.5-1.5B + bucket_size_mb: 256 # optional, default 256 + skip_sleep_wake: true # avoid vLLM multiproc sleep/wake crash on ARM64 +``` + +### How verl uses `MxCheckpointEngine` on the tensor path + +verl streams `(name, tensor)` pairs through the `CheckpointEngine` API. `MxCheckpointEngine` packs those into registered GPU buckets; per-bucket metadata carries **name, shape, dtype, and byte offset** so `receive_weights` can slice views and hand them to **`ServerAdapter.load_weights`** without an extra layout file. That is the same bucket pattern as verl’s native `nixl` engine; **`mx`** swaps ring wiring for **catalog + star** discovery as described in §2. + +--- + +## 6. 
Prototype on GB200 — Results + +### Cluster + +| Resource | Value | +|----------|-------| +| Platform | GKE DGXCloud, GB200 ARM64 | +| Nodes | 2 (trainer + rollout), `hostNetwork=true` | +| GPUs | 8 × GB200 (4 per node) | +| Node pools | `customer-gpu-w0e` (trainer), `customer-gpu-o7v` (rollout) | +| Fabric | RoCE v2, IMEX channels via DRA `compute-domain-channel` | +| UCX | v1.20.0 built from source, transports `self,sm,rc,cuda_copy,gdr_copy,tcp` | +| NIXL | 1.1.0 (main branch) | +| PyTorch | 2.6 + cu128 | +| vLLM | 0.18.1 (0.19.0 has an ARM64 multiproc `resource_tracker` bug) | +| Image | `nvcr.io/nvidian/dynamo-dev/verl-mx:latest` | + +### Deployment + +```text +Node 1 (gke-...-w0e-...-tz1d, IP 10.0.0.83) +├─ Ray head StatefulSet (verl-mx-head-0) +│ ├─ TaskRunner / CheckpointEngineManager +│ └─ 4× WorkerDict (FSDP2 trainers) + 4× MxCheckpointEngine +└─ 4× NIXL agent (rc_mlx5) + +Node 2 (gke-...-o7v-...-mflg, IP 10.0.15.225) +├─ Ray worker StatefulSet (verl-mx-worker-0) +│ ├─ 4× CheckpointEngineWorker + 4× MxCheckpointEngine +│ └─ 4× vLLM ServerAdapter +└─ 4× NIXL agent (rc_mlx5) + +MX Server + Redis (kavin namespace, reachable from both nodes over gRPC) +``` + +### What we observed + +- `[MX-DEBUG] Initializing 4 replicas in STANDALONE mode (worker_group=None)` — rollout running as dedicated CE workers, not fused WorkerDicts. +- `[MX-DEBUG] Standalone replicas ([STANDALONE x4]), using mx checkpoint engine` — repeated 11 times across the run (one per `update_weights`). +- `Backend UCX was instantiated` + `Initialized NIXL agent: ` on both nodes. +- UCX `rc_mlx5` transport negotiated — confirmed RoCE data path. +- Full lifecycle traced per step: `prepare → build_topology → init_process_group → send_weights / receive_weights → finalize`. 
+ +### Per-step `update_weights` timing (MX engine, cross-node RDMA) + +| Step | `update_weights` (s) | +|------|---------------------| +| 1 | 1.278 | +| 2 | 1.250 | +| 3 | 1.233 | +| 4 | 1.252 | +| 5 | 1.243 | +| 6 | 1.223 | +| 7 | 1.235 | +| 8 | 1.249 | +| 9 | 1.263 | +| 10 | 1.282 | +| **Avg** | **≈ 1.25s** | + +### Headline metrics + +| Metric | Value | +|--------|-------| +| Model | Qwen/Qwen2.5-1.5B (BF16, ≈ 3 GB resident) | +| Steps completed | 10 (full PPO/GRPO run) | +| Avg step time | ~8.1-8.8s | +| Avg `update_weights` (MX) | **~1.25s** | +| Avg `update_weights` (naive baseline, hybrid mode) | ~1.6s | +| verl native `nixl` (same step, standalone) | *Not benchmarked in this doc* — same NIXL READ + buckets as **`mx`**, with **ring** topology and **no** MX Server | +| Throughput | 135-163 tokens/sec | +| Transport | NIXL / UCX `rc_mlx5` (RoCE RDMA) | +| Data path | Cross-node GPU→GPU (no CPU staging, no filesystem) | + +### Manual NIXL transfer test (isolated, inside one pod) + +Before wiring the engine into the training loop, we ran a standalone `MxTrainingPublisher` → `MxRefitReceiver` test to validate the MX/NIXL data plane by itself: + +| Metric | Value | +|--------|-------| +| Payload | 10 tensors × 2 MB = 21 MB | +| Transfer time | 0.16s | +| Data integrity | All tensors byte-verified correct | +| Environment | Loopback (same GPU, single pod) | + +This proved publisher/receiver correctness before moving to cross-node. + +### How we got there — build history + +21 Docker image iterations were needed to reach a working ARM64 build. 
Dominant issues and their fixes: + +| Category | Resolution | +|----------|-----------| +| NIXL build (missing tag, `pybind11`) | Clone NIXL `main`, add `pybind11-dev` to apt install | +| PyTorch CPU-only on ARM64 | Install `torch==2.6` with `--index-url` `download.pytorch.org/whl/cu128` *before* vLLM; reinstall matching version after | +| `flash_attn` absent on ARM64 | Ship `flash_attn_compat.py` in `verl/utils` — real SDPA fallbacks for GQA attention and cross-entropy | +| vLLM 0.19 multiproc crash | Downgrade to 0.18.1 (stable on ARM64) | +| Triton JIT — missing gcc / nvcc | Base runtime on `cuda:12.8.1-devel`, add `gcc/g++` | +| `cupy` not installed | Made optional in `MxCheckpointEngine`, added torch fallback for bucket allocation | +| Ray worker → head GCS | Use head's FQDN `verl-mx-head-0.verl-mx-head.kavin.svc.cluster.local:6379` | +| `CheckpointEngineManager` fails in hybrid mode | Deploy in standalone (2-node) mode — matches built-in NIXL/NCCL engine requirements | + +### What this proves + +1. **MX works on Ray.** The `CheckpointEngine` plugin surface is sufficient to express a star-topology RDMA transfer with server-mediated discovery. +2. **Cross-node RoCE RDMA is real.** `update_weights` at 1.25s for a 3 GB model on a 2-node GB200 cluster is consistent with UCX `rc_mlx5` over RoCE and beats the in-process naive baseline even before we tune the bucket size. +3. **The ARM64 path is painful but survivable.** All the image work is in `docker/Dockerfile.mx-arm64` and the compat shim, aligned with other GB200 MX + vLLM container iterations on the same stack. +4. **Standalone rollout is the production shape.** Hybrid/colocated mode is useful for debugging but cannot drive any non-naive checkpoint engine — true for NIXL, NCCL, and MX alike. 
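
As a cross-check on the headline figures, the sketch below recomputes the average from the per-step table and derives the implied end-to-end rate and the ratio against the naive baseline. The variable names are illustrative, and the 3 GB payload is the approximate resident size quoted in the metrics table, so the derived rate is only as accurate as that estimate.

```python
from statistics import mean

# Per-step update_weights timings in seconds, copied from the timing table.
update_weights_s = [1.278, 1.250, 1.233, 1.252, 1.243,
                    1.223, 1.235, 1.249, 1.263, 1.282]

model_gb = 3.0          # assumption: approximate BF16 resident size (metrics table)
naive_baseline_s = 1.6  # naive baseline in hybrid mode (metrics table)

avg_s = mean(update_weights_s)
spread_ms = (max(update_weights_s) - min(update_weights_s)) * 1000
effective_gb_per_s = model_gb / avg_s       # end-to-end rate, not link throughput
ratio_vs_naive = naive_baseline_s / avg_s

print(f"avg update_weights : {avg_s:.3f} s")
print(f"min-max spread     : {spread_ms:.0f} ms")
print(f"effective rate     : {effective_gb_per_s:.2f} GB/s")
print(f"naive / MX         : {ratio_vs_naive:.2f}x")
```

With these inputs the average comes out at about 1.25 s, which implies roughly 2.4 GB/s end to end and about 1.28x over the naive baseline. Because `update_weights` spans the whole lifecycle (prepare, topology, process-group setup, transfer, finalize), that rate is a floor on what the RDMA transfer itself achieved rather than a measurement of link throughput.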
+ diff --git a/docs/RL/scripts/README.md b/docs/RL/scripts/README.md new file mode 100644 index 00000000..ce71366f --- /dev/null +++ b/docs/RL/scripts/README.md @@ -0,0 +1,105 @@ +# NIXL Compression Study — Capture Scripts + +Two scripts for producing the data described in [`../NIXL_COMPRESSION_STUDY.md`](../NIXL_COMPRESSION_STUDY.md). + +| Script | When to use | +|--------|-------------| +| `capture_weights_and_kv.py` | **Standalone** — capture from any HuggingFace model on any host. Doesn't require a running RL deployment. Just downloads the model and dumps weights + KV cache. CLI flags for model, dtype, device, output dir. | +| `capture_on_pod.py` | **Inside a running RL pod** — exec into a trainer pod and capture pre-step weights, simulate one AdamW step, capture post-step weights + delta + KV cache in one pass. Produces the four-directory layout we shipped to the NIXL team for Qwen2.5-1.5B. | + +## Standalone capture (any model, no cluster needed) + +```bash +pip install torch transformers safetensors + +# Smallest model — ~3 GB output, ~5 minutes total +python capture_weights_and_kv.py \ + --model Qwen/Qwen2.5-1.5B \ + --output-dir ./nixl_data \ + --dtype bfloat16 \ + --device cpu \ + --kv-seq-len 512 + +# Larger model +python capture_weights_and_kv.py \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --output-dir ./nixl_data \ + --dtype bfloat16 \ + --device cuda:0 \ + --kv-seq-len 2048 + +# Weights only / KV only +python capture_weights_and_kv.py --model --output-dir --weights-only +python capture_weights_and_kv.py --model --output-dir --kv-only +``` + +Output layout: + +```text +/ +├── weights// +│ ├── tensors/*.bin # one file per parameter, raw bytes (BF16) +│ └── manifest.json # name, shape, dtype, size, layer index, classification, stats +└── kvcache// + ├── layer_N_key.bin # one file per (layer, key/value) + ├── layer_N_value.bin + └── manifest.json +``` + +## Capture from a running pod (with RL-step simulation) + +This script is what produced the 
`RL_Qwen25/` package referenced in the NIXL request. It captures weights pre- and post- a simulated AdamW step, then computes the delta: + +```bash +# Copy the script into the pod +kubectl cp capture_on_pod.py /:/tmp/capture.py + +# Run it (uses /app/.venv/bin/python in our overlay image; adjust if different) +kubectl exec / -- /app/.venv/bin/python /tmp/capture.py \ + --model Qwen/Qwen2.5-1.5B \ + --out /tmp/nixl_capture \ + --kv-seq-len 512 \ + --lr 5e-6 + +# Copy results back +kubectl cp /:/tmp/nixl_capture ./RL_capture +``` + +Output layout (matches what we shipped to the NIXL team): + +```text +nixl_capture/ +├── weights_pre_rl/ # pre-step weight tensors + manifest.json +├── weights_post_rl/ # post-step weight tensors + manifest.json +├── weight_deltas/ # post - pre (BF16; mostly zero — see note below) +└── kv_cache/ # one prefill pass output + manifest.json +``` + +### Note on BF16 deltas + +A single AdamW step at `lr=5e-6` produces parameter updates of magnitude ~1e-8 to 1e-6, which is **below BF16's representable precision** at typical weight magnitudes (0.01–0.1). The `weight_deltas/` files will therefore be mostly zero in BF16. + +For meaningful delta-compression analysis, compute the diff in FP32 from the pre/post safetensors: + +```python +import torch +from safetensors import safe_open + +with safe_open("weights_pre_rl/...", framework="pt") as pre, \ + safe_open("weights_post_rl/...", framework="pt") as post: + for k in pre.keys(): + delta_fp32 = post.get_tensor(k).float() - pre.get_tensor(k).float() + if delta_fp32.abs().max() > 0: + print(f"{k}: max_abs_delta={delta_fp32.abs().max():.2e}") +``` + +The pre/post tensors are saved as raw BF16 (the on-the-wire dtype). The FP32 delta is the meaningful analysis target — this is the signal nvCOMP would compress in a delta-transfer scheme. 
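
The precision argument in the note on BF16 deltas can be checked with plain bit arithmetic. The sketch below is a standard-library illustration, not part of the capture scripts: `round_to_bf16` and `bf16_ulp` are hypothetical helper names, and the 0.05 weight and 5e-6 update are example values chosen to match the magnitudes discussed in the note.

```python
import struct


def _f32_bits(x: float) -> int:
    """Bit pattern of x after rounding to IEEE-754 float32."""
    return struct.unpack("<I", struct.pack("<f", x))[0]


def round_to_bf16(x: float) -> float:
    """Round x to the nearest BF16 value (round-to-nearest-even).

    BF16 is the top 16 bits of a float32, so rounding means adding a bias to
    the low 16 bits and then clearing them. NaN/inf handling is omitted.
    """
    bits = _f32_bits(x)
    lsb = (bits >> 16) & 1                       # lowest bit that BF16 keeps
    bits = (bits + 0x7FFF + lsb) & 0xFFFF0000    # round, then drop low 16 bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def bf16_ulp(x: float) -> float:
    """Spacing between adjacent BF16 values near x (normal numbers only)."""
    exponent = ((_f32_bits(x) >> 23) & 0xFF) - 127
    return 2.0 ** (exponent - 7)                 # BF16 stores 7 mantissa bits


weight = 0.05   # example weight magnitude (assumption)
update = 5e-6   # example per-parameter update, same order as the lr (assumption)

before = round_to_bf16(weight)            # value as stored in BF16
after = round_to_bf16(before + update)    # apply the update, store in BF16 again

print(f"BF16 spacing near {weight}: {bf16_ulp(weight):.2e}")
print(f"stored before: {before!r}, stored after: {after!r}")
print(f"update survives BF16 rounding: {after != before}")
```

Near 0.05 the BF16 spacing is about 2.4e-4, so an update of 5e-6 is far below half a step and the stored value does not change. The same reasoning means that upcasting two BF16-stored tensors to FP32 only reveals changes that already survived this rounding; smaller updates would have to come from FP32 master weights, which the capture scripts in this change do not appear to save.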
+ +## What gets captured + +See [`../NIXL_COMPRESSION_STUDY.md`](../NIXL_COMPRESSION_STUDY.md) for the full breakdown of: + +- Per-tensor layout (Qwen3-0.6B and Qwen2.5-1.5B examples) +- KV cache shape + scaling table +- Compression-relevant properties +- Where these tensors fit in the NIXL transfer path diff --git a/docs/RL/scripts/capture_on_pod.py b/docs/RL/scripts/capture_on_pod.py new file mode 100644 index 00000000..5e3e510b --- /dev/null +++ b/docs/RL/scripts/capture_on_pod.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +"""Capture weight and KV cache data from a running PRIME-RL / verl deployment. + +Designed to be exec'd inside a trainer pod. Captures four artifacts in the +same shape we shipped to the NIXL nvCOMP compression team for Qwen2.5-1.5B: + + weights_pre_rl/ raw .bin tensors + manifest.json (pre-step state dict) + weights_post_rl/ raw .bin tensors + manifest.json (after one AdamW step) + weight_deltas/ raw .bin tensors + manifest.json (post - pre) + kv_cache/ raw .bin tensors + manifest.json (one prefill pass) + +Then a final summary line tells you how to `kubectl cp` it out. 
+ +Usage (inside pod): + python3 capture_on_pod.py + python3 capture_on_pod.py --model Qwen/Qwen2.5-7B --out /tmp/nixl_capture --kv-seq-len 1024 + +Usage (from host, no pod): + kubectl cp capture_on_pod.py /:/tmp/capture.py + kubectl exec / -- python3 /tmp/capture.py --model + kubectl cp /:/tmp/nixl_capture ./RL_capture +""" +import argparse, json, os, time, torch +from pathlib import Path +from transformers import AutoModelForCausalLM, AutoTokenizer + +parser = argparse.ArgumentParser() +parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B", + help="HuggingFace model name (must match the running RL deployment)") +parser.add_argument("--out", default="/tmp/nixl_capture", + help="Output directory inside the pod") +parser.add_argument("--kv-seq-len", type=int, default=512, + help="Sequence length for the KV cache prefill pass") +parser.add_argument("--lr", type=float, default=5e-6, + help="Learning rate for the simulated AdamW step (matches PRIME-RL default)") +args = parser.parse_args() + +MODEL = args.model +OUT = Path(args.out) +OUT.mkdir(parents=True, exist_ok=True) + +def tensor_stats(t): + ft = t.float() + return {"min": float(ft.min()), "max": float(ft.max()), "mean": float(ft.mean()), + "std": float(ft.std()), "abs_mean": float(ft.abs().mean()), + "zero_frac": float((t == 0).float().mean())} + +def classify(name): + for k, v in [("embed", "embedding"), ("lm_head", "lm_head"), ("norm", "norm"), + ("q_proj", "attn_q"), ("k_proj", "attn_k"), ("v_proj", "attn_v"), + ("o_proj", "attn_o"), ("gate_proj", "mlp_gate"), ("up_proj", "mlp_up"), + ("down_proj", "mlp_down")]: + if k in name: return v + return "other" + +def layer_idx(name): + parts = name.split(".") + for i, p in enumerate(parts): + if p == "layers" and i + 1 < len(parts) and parts[i+1].isdigit(): + return int(parts[i+1]) + return -1 + +print(f"Loading {MODEL}...") +tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained(MODEL, 
torch_dtype=torch.bfloat16, trust_remote_code=True) +model.eval() + +# --- 1. Dump weights (pre-RL step) --- +print("\n=== Capturing pre-RL weights ===") +wdir = OUT / "weights_pre_rl" +wdir.mkdir(exist_ok=True) +manifest = {"model": MODEL, "dtype": "bfloat16", "capture": "pre_rl_weights", + "description": "Exact weight tensors transferred during RL refit (training->inference via NIXL RDMA)", + "tensors": []} +total = 0 +for name, param in model.named_parameters(): + t = param.data.contiguous() + raw = bytes(t.untyped_storage())[:t.numel() * t.element_size()] + fname = name.replace(".", "_") + ".bin" + (wdir / fname).write_bytes(raw) + manifest["tensors"].append({"name": name, "file": fname, "shape": list(t.shape), + "dtype": str(t.dtype), "size_bytes": len(raw), "numel": t.numel(), + "layer": layer_idx(name), "type": classify(name), "stats": tensor_stats(t)}) + total += len(raw) +manifest["total_bytes"] = total +manifest["total_gb"] = round(total / 1e9, 3) +manifest["num_tensors"] = len(manifest["tensors"]) +cfg = model.config +manifest["model_config"] = {"num_hidden_layers": cfg.num_hidden_layers, "hidden_size": cfg.hidden_size, + "intermediate_size": cfg.intermediate_size, "num_attention_heads": cfg.num_attention_heads, + "num_key_value_heads": cfg.num_key_value_heads, "vocab_size": cfg.vocab_size} +(wdir / "manifest.json").write_text(json.dumps(manifest, indent=2)) +print(f" {manifest['num_tensors']} tensors, {manifest['total_gb']} GB -> {wdir}") + +# --- 2. KV cache --- +print(f"\n=== Capturing KV cache (seq_len={args.kv_seq_len}) ===") +kvdir = OUT / "kv_cache" +kvdir.mkdir(exist_ok=True) +prompt = "The quick brown fox jumps over the lazy dog. 
" * (args.kv_seq_len // 10 + 1) +inputs = tokenizer(prompt, return_tensors="pt", max_length=args.kv_seq_len, truncation=True) +with torch.no_grad(): + outputs = model(**inputs, use_cache=True) +kv = outputs.past_key_values +kv_manifest = {"model": MODEL, "capture": "kv_cache_prefill", "seq_len": int(inputs["input_ids"].shape[1]), + "description": "KV cache from prefill pass - transferred prefill->decode via NIXL in disagg inference", + "tensors": []} +kv_total = 0 +for li, layer_kv in enumerate(kv): + for ki, kn in enumerate(["key", "value"]): + t = layer_kv[ki].contiguous() + raw = bytes(t.untyped_storage())[:t.numel() * t.element_size()] + fname = f"layer_{li}_{kn}.bin" + (kvdir / fname).write_bytes(raw) + kv_manifest["tensors"].append({"name": f"layer_{li}.{kn}", "file": fname, + "shape": list(t.shape), "dtype": str(t.dtype), "size_bytes": len(raw), + "layer": li, "kv_type": kn, "stats": tensor_stats(t)}) + kv_total += len(raw) +kv_manifest["total_bytes"] = kv_total +kv_manifest["total_mb"] = round(kv_total / 1e6, 3) +kv_manifest["kv_config"] = {"num_layers": len(kv), + "num_kv_heads": cfg.num_key_value_heads, + "head_dim": cfg.hidden_size // cfg.num_attention_heads} +(kvdir / "manifest.json").write_text(json.dumps(kv_manifest, indent=2)) +print(f" {len(kv_manifest['tensors'])} tensors, {kv_manifest['total_mb']} MB -> {kvdir}") + +# --- 3. 
Simulate one RL step and capture post-RL weights + delta --- +print(f"\n=== Simulating RL step (AdamW, lr={args.lr}, dummy loss) ===") +model.train() +optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01) +dummy_input = tokenizer("Hello world", return_tensors="pt") +output = model(**dummy_input, labels=dummy_input["input_ids"]) +output.loss.backward() +optimizer.step() +optimizer.zero_grad() +model.eval() + +print("\n=== Capturing post-RL weights ===") +wdir2 = OUT / "weights_post_rl" +wdir2.mkdir(exist_ok=True) +# Record the learning rate actually used instead of a hard-coded value +manifest2 = {"model": MODEL, "dtype": "bfloat16", "capture": "post_rl_weights", + "description": f"Weights after 1 RL optimizer step (lr={args.lr}, AdamW)", "tensors": []} +total2 = 0 +for name, param in model.named_parameters(): + t = param.data.contiguous() + raw = bytes(t.untyped_storage())[:t.numel() * t.element_size()] + fname = name.replace(".", "_") + ".bin" + (wdir2 / fname).write_bytes(raw) + manifest2["tensors"].append({"name": name, "file": fname, "shape": list(t.shape), + "dtype": str(t.dtype), "size_bytes": len(raw), "numel": t.numel(), + "layer": layer_idx(name), "type": classify(name), "stats": tensor_stats(t)}) + total2 += len(raw) +manifest2["total_bytes"] = total2 +manifest2["total_gb"] = round(total2 / 1e9, 3) +(wdir2 / "manifest.json").write_text(json.dumps(manifest2, indent=2)) +print(f" {len(manifest2['tensors'])} tensors, {manifest2['total_gb']} GB -> {wdir2}") + +# --- 4. Compute and save deltas --- +print("\n=== Computing weight deltas (post - pre) ===") +ddir = OUT / "weight_deltas" +ddir.mkdir(exist_ok=True) +delta_manifest = {"model": MODEL, "capture": "weight_delta_1_step", + "description": ( + f"Difference between weights after 1 RL step vs before. " + f"RL uses lr={args.lr} so deltas are tiny; at BF16 most are exactly zero " + f"(below mantissa precision). 
For meaningful delta-compression analysis, " + f"compute diffs in FP32 from the pre/post safetensors instead of using " + f"this BF16-stored delta directly." + ), + "tensors": []} +dtotal = 0 +pre_files = {m["name"]: m["file"] for m in manifest["tensors"]} +for info in manifest2["tensors"]: + pre_raw = (OUT / "weights_pre_rl" / pre_files[info["name"]]).read_bytes() + post_raw = (wdir2 / info["file"]).read_bytes() + pre_t = torch.frombuffer(bytearray(pre_raw), dtype=torch.bfloat16).reshape(info["shape"]) + post_t = torch.frombuffer(bytearray(post_raw), dtype=torch.bfloat16).reshape(info["shape"]) + delta = post_t - pre_t + delta_raw = bytes(delta.contiguous().untyped_storage())[:delta.numel() * delta.element_size()] + fname = "delta_" + info["file"] + (ddir / fname).write_bytes(delta_raw) + delta_manifest["tensors"].append({"name": info["name"], "file": fname, + "shape": info["shape"], "dtype": "bfloat16", "size_bytes": len(delta_raw), + "stats": tensor_stats(delta)}) + dtotal += len(delta_raw) +delta_manifest["total_bytes"] = dtotal +delta_manifest["total_gb"] = round(dtotal / 1e9, 3) +(ddir / "manifest.json").write_text(json.dumps(delta_manifest, indent=2)) +print(f" {len(delta_manifest['tensors'])} delta tensors, {delta_manifest['total_gb']} GB -> {ddir}") + +# --- Summary --- +print(f"\n=== DONE ===") +print(f"Output: {OUT}") +print(f" weights_pre_rl/ : {manifest['total_gb']} GB ({manifest['num_tensors']} tensors)") +print(f" weights_post_rl/ : {manifest2['total_gb']} GB") +print(f" weight_deltas/ : {delta_manifest['total_gb']} GB") +print(f" kv_cache/ : {kv_manifest['total_mb']} MB ({len(kv_manifest['tensors'])} tensors)") +print(f"\nTo copy out: kubectl cp /:{OUT} ./RL_capture") diff --git a/docs/RL/scripts/capture_weights_and_kv.py b/docs/RL/scripts/capture_weights_and_kv.py new file mode 100644 index 00000000..38c6d215 --- /dev/null +++ b/docs/RL/scripts/capture_weights_and_kv.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +"""Capture raw model weights and KV 
cache data for NIXL nvCOMP compression study. + +Outputs: + weights/{model_name}/tensors/*.bin — raw weight tensors (as transferred during RL refit) + weights/{model_name}/manifest.json — per-tensor metadata + kvcache/{model_name}/*.bin — KV cache tensors from a sample forward pass + kvcache/{model_name}/manifest.json — per-KV-tensor metadata + +Usage: + python capture_weights_and_kv.py --model Qwen/Qwen2.5-1.5B --output-dir ./nixl_data + python capture_weights_and_kv.py --model Qwen/Qwen2.5-1.5B --output-dir ./nixl_data --kv-only + python capture_weights_and_kv.py --model Qwen/Qwen2.5-1.5B --output-dir ./nixl_data --weights-only +""" + +import argparse +import json +import os +import time +from pathlib import Path + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def sanitize_name(name: str) -> str: + return name.replace("/", "_").replace(".", "_") + + +def tensor_stats(t: torch.Tensor) -> dict: + with torch.no_grad(): + ft = t.float() + return { + "min": float(ft.min()), + "max": float(ft.max()), + "mean": float(ft.mean()), + "std": float(ft.std()), + "abs_mean": float(ft.abs().mean()), + "zero_fraction": float((t == 0).float().mean()), + } + + +def classify_tensor(name: str) -> str: + if "embed" in name: + return "embedding" + if "lm_head" in name: + return "lm_head" + if "layernorm" in name or "norm" in name: + return "norm" + if "q_proj" in name: + return "attention_q" + if "k_proj" in name: + return "attention_k" + if "v_proj" in name: + return "attention_v" + if "o_proj" in name: + return "attention_o" + if "gate_proj" in name or "w1" in name: + return "mlp_gate" + if "up_proj" in name or "w3" in name: + return "mlp_up" + if "down_proj" in name or "w2" in name: + return "mlp_down" + return "other" + + +def get_layer_index(name: str) -> int: + parts = name.split(".") + for i, part in enumerate(parts): + if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit(): + return int(parts[i + 1]) + return -1 + + +def 
capture_weights(model, model_name: str, output_dir: Path, dtype_name: str): + """Dump all model weight tensors as raw binary files + manifest.""" + safe_name = sanitize_name(model_name) + weight_dir = output_dir / "weights" / safe_name / "tensors" + weight_dir.mkdir(parents=True, exist_ok=True) + + manifest = { + "model_name": model_name, + "dtype": dtype_name, + "capture_type": "model_weights_for_rl_refit", + "description": ( + "These are the exact weight tensors transferred from training GPUs to " + "inference GPUs during the RL refit (weight sync) phase. In RL post-training, " + "after each optimizer step, the full model state dict is gathered and sent to " + "the inference engine (vLLM). These tensors represent that payload." + ), + "tensors": [], + } + + total_bytes = 0 + for name, param in model.named_parameters(): + t = param.data.contiguous().cpu() + raw_bytes = bytes(t.untyped_storage())[:t.numel() * t.element_size()] + fname = sanitize_name(name) + ".bin" + (weight_dir / fname).write_bytes(raw_bytes) + + info = { + "name": name, + "file": f"tensors/{fname}", + "shape": list(t.shape), + "dtype": str(t.dtype), + "size_bytes": len(raw_bytes), + "numel": t.numel(), + "layer_index": get_layer_index(name), + "tensor_type": classify_tensor(name), + "stats": tensor_stats(t), + } + manifest["tensors"].append(info) + total_bytes += len(raw_bytes) + + manifest["total_tensors"] = len(manifest["tensors"]) + manifest["total_bytes"] = total_bytes + manifest["total_gb"] = round(total_bytes / 1e9, 3) + + config = model.config + manifest["model_config"] = { + "num_hidden_layers": getattr(config, "num_hidden_layers", None), + "hidden_size": getattr(config, "hidden_size", None), + "intermediate_size": getattr(config, "intermediate_size", None), + "num_attention_heads": getattr(config, "num_attention_heads", None), + "num_key_value_heads": getattr(config, "num_key_value_heads", None), + "vocab_size": getattr(config, "vocab_size", None), + "max_position_embeddings": 
getattr(config, "max_position_embeddings", None), + "architecture": config.architectures[0] if hasattr(config, "architectures") and config.architectures else None, + } + + manifest_path = output_dir / "weights" / safe_name / "manifest.json" + manifest_path.write_text(json.dumps(manifest, indent=2)) + print(f"Weights captured: {len(manifest['tensors'])} tensors, {manifest['total_gb']} GB → {weight_dir}") + return manifest + + +def capture_kvcache(model, tokenizer, model_name: str, output_dir: Path, seq_len: int = 512): + """Run a forward pass and capture the KV cache tensors.""" + safe_name = sanitize_name(model_name) + kv_dir = output_dir / "kvcache" / safe_name + kv_dir.mkdir(parents=True, exist_ok=True) + + prompt = "The quick brown fox jumps over the lazy dog. " * (seq_len // 10) + inputs = tokenizer(prompt, return_tensors="pt", max_length=seq_len, truncation=True) + + device = next(model.parameters()).device + inputs = {k: v.to(device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = model(**inputs, use_cache=True) + + past_kv = outputs.past_key_values + + manifest = { + "model_name": model_name, + "capture_type": "kv_cache_prefill_to_decode", + "description": ( + "These are the KV cache tensors produced during the prefill phase and " + "sent to decode workers in disaggregated inference (prefill/decode split). " + "In NIXL-based KV transfer, these are the exact tensors transferred " + "GPU-to-GPU between prefill and decode nodes." 
+ ), + "sequence_length": int(inputs["input_ids"].shape[1]), + "batch_size": 1, + "tensors": [], + } + + total_bytes = 0 + for layer_idx, layer_kv in enumerate(past_kv): + for kv_idx, kv_name in enumerate(["key", "value"]): + t = layer_kv[kv_idx].contiguous().cpu() + raw_bytes = bytes(t.untyped_storage())[:t.numel() * t.element_size()] + fname = f"layer_{layer_idx}_{kv_name}.bin" + (kv_dir / fname).write_bytes(raw_bytes) + + info = { + "name": f"layer_{layer_idx}.{kv_name}", + "file": fname, + "shape": list(t.shape), + "dtype": str(t.dtype), + "size_bytes": len(raw_bytes), + "layer_index": layer_idx, + "kv_type": kv_name, + "stats": tensor_stats(t), + } + manifest["tensors"].append(info) + total_bytes += len(raw_bytes) + + manifest["total_tensors"] = len(manifest["tensors"]) + manifest["total_bytes"] = total_bytes + manifest["total_mb"] = round(total_bytes / 1e6, 3) + + config = model.config + manifest["kv_config"] = { + "num_layers": len(past_kv), + "num_kv_heads": getattr(config, "num_key_value_heads", getattr(config, "num_attention_heads", None)), + "head_dim": getattr(config, "hidden_size", 0) // getattr(config, "num_attention_heads", 1), + "kv_dtype": str(past_kv[0][0].dtype), + } + + manifest_path = kv_dir / "manifest.json" + manifest_path.write_text(json.dumps(manifest, indent=2)) + print(f"KV cache captured: {len(manifest['tensors'])} tensors, {manifest['total_mb']} MB → {kv_dir}") + return manifest + + +def main(): + parser = argparse.ArgumentParser(description="Capture weight and KV cache data for NIXL compression study") + parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-1.5B", help="HuggingFace model name") + parser.add_argument("--output-dir", type=str, default="./nixl_compression_data", help="Output directory") + parser.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"]) + parser.add_argument("--device", type=str, default="cpu", help="Device to load model on (cpu or cuda:0)") + 
parser.add_argument("--weights-only", action="store_true", help="Only capture weights, skip KV cache") + parser.add_argument("--kv-only", action="store_true", help="Only capture KV cache, skip weights") + parser.add_argument("--kv-seq-len", type=int, default=512, help="Sequence length for KV cache capture") + args = parser.parse_args() + + output_dir = Path(args.output_dir) + dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32} + dtype = dtype_map[args.dtype] + + print(f"Loading model: {args.model} (dtype={args.dtype}, device={args.device})") + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained( + args.model, torch_dtype=dtype, trust_remote_code=True, device_map=args.device, + ) + model.eval() + + if not args.kv_only: + print("\n=== Capturing model weights (RL refit payload) ===") + capture_weights(model, args.model, output_dir, args.dtype) + + if not args.weights_only: + print(f"\n=== Capturing KV cache (prefill→decode, seq_len={args.kv_seq_len}) ===") + capture_kvcache(model, tokenizer, args.model, output_dir, seq_len=args.kv_seq_len) + + print(f"\nDone. 
Data written to {output_dir}/") + print("Send the output directory to the NIXL compression team.") + + +if __name__ == "__main__": + main() diff --git a/docs/slides/diagram-architecture.svg b/docs/slides/diagram-architecture.svg new file mode 100644 index 00000000..9e77b4a4 --- /dev/null +++ b/docs/slides/diagram-architecture.svg @@ -0,0 +1,99 @@ + + + + + + + + + + MODELEXPRESS RL REFIT ARCHITECTURE + + + + TRAINING WORKERS + + + FSDP2 / Megatron-LM + + + WeightExtractor + + Gather params per bucket + + + MxTrainingPublisher + + Register tensors with NIXL + Publish metadata (gRPC) + + + NIXL Agent (per GPU) + UCX/RDMA + registration + + + + MX SERVER + + + gRPC P2P Coordination + + + Redis / K8s CRD + + + Version Tracking + + + Refit Coordination + + + + INFERENCE WORKERS + + + vLLM / SGLang + + + MxRefitReceiver + + Poll for new versions + Receive RDMA writes + + + NIXL Agent (per GPU) + + + Apply weights in-place + + + Resume rollout inference + + + + + gRPC + + + + gRPC + + + + + + + RDMA WRITE + + + + GPU VRAM GPU VRAM Bypasses CPU, disk, and collective overhead + + + Automatic fallback: RDMA CUDA IPC TCP + + diff --git a/docs/slides/diagram-component-stack.svg b/docs/slides/diagram-component-stack.svg new file mode 100644 index 00000000..5f23f3a0 --- /dev/null +++ b/docs/slides/diagram-component-stack.svg @@ -0,0 +1,148 @@ + + + + + + + COMPONENT ARCHITECTURE  NEW & EXISTING + + + + TRAINING WORKERS + + + RL FRAMEWORK + NeMo RL / verl / PRIME-RL + + + TRAINING BACKEND + FSDP2 / Megatron-LM + + + + + + WeightExtractor + Gather params per bucket (FSDP2 / Megatron) + NEW + + + + + + MxTrainingPublisher + Register tensors with NIXL + Publish metadata + step version (gRPC) + NEW + + + + + + ResharderPlugin + Gather-then-shard / Direct-match / Auto + NEW + + + + + + NIXL Agent (per GPU) + UCX backend RDMA WRITE capability + EXISTING + + + + MODELEXPRESS SERVER (RUST) + + + gRPC P2P Service + EXISTING + + + Redis / K8s CRD Metadata + EXISTING + + + Refit Version 
Tracking + NEW + + + Bucket Coordination + NEW + + + Heartbeat / Stale Reaper + EXISTING + + + p2p.proto (extended) + MODIFIED + + + + INFERENCE WORKERS + + + INFERENCE ENGINE + vLLM / SGLang + + + MxRefitReceiver + Poll for new weight versions + Coordinate with MX Server + NEW + + + NIXL Agent (per GPU) + UCX Pre-registered receive buffers + EXISTING + + + Apply Weights In-Place + FP8 quantization on receiver side + + + SyncPolicyController + Sync One-step-off Fully async + NEW + + + Resume Rollout Generation + + + + + + gRPC + + + + + gRPC + + + + + + + RDMA WRITE (per bucket) + + + + LEGEND + + New component + + Modified + + Existing (NIXL, gRPC, Redis) + + Metadata (gRPC) + + Data plane + diff --git a/docs/slides/diagram-framework-comparison.svg b/docs/slides/diagram-framework-comparison.svg new file mode 100644 index 00000000..4f3e256c --- /dev/null +++ b/docs/slides/diagram-framework-comparison.svg @@ -0,0 +1,112 @@ + + + + + + + + + + FRAMEWORK INTEGRATION COMPARISON + + + + NeMo RL + + + verl + + + PRIME-RL + + + + + Training Backend + + + DTensor / Megatron + + + FSDP / FSDP2 / Megatron + + + FSDP2 (EP/CP) + + + + Inference Backend + + + vLLM, SGLang + + + vLLM, SGLang, HF + + + vLLM + + + + Current Sync + + + ZMQ IPC / HTTP / NCCL + + + NCCL / CheckpointEngine + + + Filesystem + HTTP + + + + MX Insertion + + + refit_policy_ + generation() + + + CheckpointEngine + ABC (v0.7) + + + Orchestrator + relay_weights() + + + + Primary Benefit + + + RDMA replaces ZMQ/NCCL + + + Multi-node w/o filesystem + + + Eliminates disk I/O + + + + Orchestration + + + Ray actors + + + Ray (driver + workers) + + + Lightweight orchestrator + + + + Shared Abstraction: WeightSyncBackend + Framework-agnostic interface decoupling transport mechanism from sync policy + diff --git a/docs/slides/diagram-rl-loop-bottleneck.svg b/docs/slides/diagram-rl-loop-bottleneck.svg new file mode 100644 index 00000000..272801dd --- /dev/null +++ b/docs/slides/diagram-rl-loop-bottleneck.svg @@ -0,0 
+1,82 @@ + + + + + + + + + + + + + + + + + + + ON-POLICY RL TRAINING LOOP + + + + ROLLOUT + Inference Engine + vLLM / SGLang + Generate trajectories + + + + + + + REWARD + Compute Rewards + Rule-based / Model RM + + + + + + + TRAINING + Policy Gradient Update + FSDP2 / Megatron-LM + GRPO / PPO / DAPO + + + + + + + WEIGHT SYNC (REFIT) + BOTTLENECK + + + ILLUSTRATIVE WALL-CLOCK BREAKDOWN + + + Rollout (40%) + + + Rew + + + Train (20%) + + + REFIT (30%) + + + + ▲ Up to 30-40% for 70B+ models + + + + Filesystem ~20s+ | NCCL ~10s | ZMQ IPC ~3-5s | MX RDMA ~5s (target for 70B) + + diff --git a/docs/slides/diagram-transfer-flow.svg b/docs/slides/diagram-transfer-flow.svg new file mode 100644 index 00000000..aead0cfe --- /dev/null +++ b/docs/slides/diagram-transfer-flow.svg @@ -0,0 +1,88 @@ + + + + + + + + + + RL REFIT TRANSFER FLOW (ONE TRAINING STEP) + + + + Training (rank k) + + + MX Server + + + Inference (rank k) + + + + + + + + + optimizer.step() + + + + WeightExtractor.extract() + + + + + PublishMetadata(step=N) + + + + + UpdateStatus(READY) + + + + serving rollout... + + + + + poll_for_update() + + + + + GetMetadata(rank k) + + + + add_remote_agent(NIXL blob) + + + + + + RDMA WRITE (bucket_0) + + + + apply_weights(bucket_0) + + + + repeat for bucket_1 ... bucket_N (bounded GPU memory) + + + + resume_inference() + + + + next training step + diff --git a/docs/slides/mx-rl-integration-slides.html b/docs/slides/mx-rl-integration-slides.html new file mode 100644 index 00000000..8b21942e --- /dev/null +++ b/docs/slides/mx-rl-integration-slides.html @@ -0,0 +1,790 @@ + + + + + +ModelExpress RL Weight Update Integration + + + + + + +
+ + +
+
+ NVIDIA NIXL + ModelExpress + RL Post-Training +
+

ModelExpress for
RL Weight Updates

+

+ Extending GPU-to-GPU RDMA transfers from inference scaling to the training→inference refit boundary in reinforcement learning post-training. +

+

+ Target frameworks: NeMo RL · verl · PRIME-RL +

+
April 2026 — Integration Design Overview
+
+ + +
+
The Problem
+

The Weight Sync Bottleneck

+ +

+ On-policy RL (GRPO, PPO, DAPO) alternates between rollout generation on inference GPUs and gradient updates on training GPUs. After every training step, updated weights must reach inference before the next rollout — this refit phase stalls both sides. +

+ +
+
Wall-clock time breakdown (illustrative)
+
+
Rollout
+
Rew
+
Train
+
REFIT
+
+
▲ Up to 30–40% of wall-clock for large models
+
+ +
Current refit latency (70B-class model, multi-node)
+
+
+
Filesystem (PRIME-RL)
+
~20s+
+
+
+
NCCL Broadcast (NeMo RL)
+
~10s
+
+
+
ZMQ IPC (NeMo RL, co-loc)
+
~3-5s
+
+
+
MX RDMA P2P (target)
+
~5s
+
+
+
+ + +
+
The Solution
+

ModelExpress for Training→Inference Refit

+ +

+ Extend MX from inference-to-inference P2P to the training→inference boundary. Training workers register updated weights with NIXL, publish metadata to the MX Server, and RDMA-WRITE directly into inference GPU memory — bypassing CPU, disk, and collective overheads. +

+ +
+
+
+
+
Training Workers
+
FSDP2 / Megatron
+
WeightExtractor
+
MxTrainingPublisher
+
NIXL Agent per GPU
+
+
+
+
MX Server
+
gRPC Metadata Coord
+
Redis / K8s CRD
+
Version Tracking
+
+
+
+
Inference Workers
+
vLLM / SGLang
+
MxRefitReceiver
+
NIXL Agent per GPU
+
+
+
+ RDMA WRITE — GPU VRAM → GPU VRAM — bypasses CPU & disk +
+
+
+ +
+
+
~5s
+
Llama-3.3-70B (140 GB)
+
+
+
~15s
+
DeepSeek-V3 MoE (681 GB)
+
+
+
0
+
CPU staging required
+
+
+
Auto
+
IPC → RDMA → TCP fallback
+
+
+
+ + +
+
Architecture
+

Component Deep-Dive

+ +
+ + + + TRAINING WORKERS + + + + RL FRAMEWORK + NeMo RL / verl / PRIME-RL + + + + TRAINING BACKEND + FSDP2 / Megatron-LM + + + + + + + + WeightExtractor + Gather params • Bucket iteration + NEW + + + + + + + + MxTrainingPublisher + Register tensors with NIXL + Publish metadata (gRPC) • Version tag + NEW + + + + + + + + ResharderPlugin + Gather-then-shard • Direct-match + NEW + + + + + + + + NIXL Agent + UCX backend • Per-GPU • RDMA WRITE + EXISTING + + + + MODELEXPRESS SERVER + + + gRPC P2P Service + EXISTING + + + Redis / K8s CRD Backend + EXISTING + + + Refit Version Tracking + NEW + + + Bucket Coordination + NEW + + + Heartbeat / Reaper + EXISTING + + + p2p.proto (extended) + MODIFIED + + + + INFERENCE WORKERS + + + + INFERENCE ENGINE + vLLM / SGLang + + + + MxRefitReceiver + Poll for new weight versions + Coordinate with MX Server + NEW + + + + NIXL Agent + UCX backend • Pre-registered buffers + EXISTING + + + + Apply Weights In-Place + FP8 quantization on receiver + + + + SyncPolicyController + Sync • One-step-off • Async + NEW + + + + Resume Rollout Generation + + + + + + gRPC + + + + + gRPC + + + + + + + RDMA WRITE (per bucket) + + + + LEGEND + + New component + + Modified + + Existing + + Metadata + + Data plane (RDMA) + +
+
+ + +
+
Integration Map
+

Three Frameworks, One Abstraction

+ +

+ A framework-agnostic WeightSyncBackend decouples the transfer mechanism from each framework's orchestration and sync policy. +

+ + + + + + + + + + + + + + + + + +
| Dimension | NeMo RL | verl | PRIME-RL |
|-----------|---------|------|----------|
| Training Backend | DTensor / Megatron | FSDP / FSDP2 / Megatron | FSDP2 (EP/CP) |
| Inference Backend | vLLM, SGLang | vLLM, SGLang, HF | vLLM |
| Current Sync | ZMQ IPC / HTTP / NCCL | NCCL / CheckpointEngine | Filesystem + HTTP |
| MX Insertion Point | refit_policy_generation() | CheckpointEngine ABC | Orchestrator relay |
| Primary Benefit | RDMA replaces ZMQ/NCCL | Multi-node sans filesystem | Eliminates disk I/O |
+ +
+
+

NeMo RL

+

New branch alongside ZMQ IPC and NCCL in refit function. Bucket-streamed transfer maps to MX publish. Ray actor integration.

+
+
+

verl

+

ModelExpressCheckpointEngine implements v0.7 CheckpointEngine ABC. Targets async server mode. Engine mode stays NCCL.

+
+
+

PRIME-RL

+

Replaces filesystem relay in orchestrator. For cross-DC, MX acts as fast intra-cluster delivery under SHARDCAST.

+
+
+
+ + +
+
Delivery
+

Phased Integration Plan

+ +
+
+
Phase 1 — Weeks 1-4
+

Foundation

+
    +
  • WeightExtractor abstraction (FSDP2 + Megatron)
  • +
  • MxTrainingPublisher with NIXL + gRPC
  • +
  • Proto extensions (RL_REFIT, training_step)
  • +
  • Single-node RDMA validation
  • +
  • Benchmark vs. ZMQ IPC baseline
  • +
+
+
+
Phase 2 — Weeks 5-10
+

Framework Integrations

+
    +
  • NeMo RL: MX branch in refit function
  • +
  • verl: ModelExpressCheckpointEngine
  • +
  • PRIME-RL: Orchestrator MX relay
  • +
  • Fallback to existing mechanisms
  • +
  • E2E GRPO/PPO correctness tests
  • +
+
+
+
Phase 3 — Weeks 11-13
+

Hardening

+
    +
  • ResharderPlugin (gather-then-shard)
  • +
  • Error handling + fallback paths
  • +
  • MoE bucket completion tracking
  • +
  • FP8 quantization validation
  • +
  • Multi-node benchmarks (70B, MoE)
  • +
+
+
+ +
+
+

Llama-3.1-8B

+

2 nodes · NCCL ~3s → MX ~1s

+
+
+

Llama-3.3-70B

+

4 nodes · ~10-20s → MX ~5s

+
+
+

DeepSeek-V3 MoE

+

8+ nodes · ~30s+ → MX ~15s

+
+
+ +
    +
  • Key risk: Parallelism layout mismatch — mitigated by ResharderPlugin and config alignment guidance
  • +
  • Dependency: InfiniBand for full RDMA benefit; clean fallback to NCCL/TCP/filesystem
  • +
+
+ +
+ +
01 / 06
+ + + + + + diff --git a/docs/slides/mx-rl-integration-slides.md b/docs/slides/mx-rl-integration-slides.md new file mode 100644 index 00000000..10ffb5d8 --- /dev/null +++ b/docs/slides/mx-rl-integration-slides.md @@ -0,0 +1,199 @@ +# ModelExpress for RL Weight Updates + +> **April 2026 — Integration Design Overview** +> +> `NVIDIA NIXL` · `ModelExpress` · `RL Post-Training` + +Extending GPU-to-GPU RDMA transfers from **inference scaling** to the **training→inference refit boundary** in reinforcement learning post-training. + +Target frameworks: **NeMo RL** · **verl** · **PRIME-RL** + +--- + +## Slide 1 — Title + +**ModelExpress for RL Weight Updates** + +Extending GPU-to-GPU RDMA transfers from **inference scaling** to the **training→inference refit boundary** in reinforcement learning post-training. + +Target frameworks: **NeMo RL** · **verl** · **PRIME-RL** + +--- + +## Slide 2 — The Problem: The Weight Sync Bottleneck + +On-policy RL (GRPO, PPO, DAPO) alternates between rollout generation on inference GPUs and gradient updates on training GPUs. After every training step, updated weights must reach inference before the next rollout — this **refit phase** stalls both sides. + +### Wall-clock time breakdown (illustrative) + +```text +| Rollout (40%) | Rew | Train (20%) | ██ REFIT (30%) ██ | + ▲ BOTTLENECK ▲ +``` + +> Up to 30–40% of wall-clock for 70B+ models + +### Current refit latency (70B-class model, multi-node) + +| Method | Latency | +|--------|---------| +| Filesystem (PRIME-RL) | ~20s+ | +| NCCL Broadcast (NeMo RL) | ~10s | +| ZMQ IPC (NeMo RL, co-located) | ~3-5s | +| **MX RDMA P2P (target)** | **~5s** | + +--- + +## Slide 3 — The Solution: ModelExpress for Training→Inference Refit + +Extend MX from inference-to-inference P2P to the training→inference boundary. Training workers register updated weights with NIXL, publish metadata to the MX Server, and RDMA-WRITE directly into inference GPU memory — bypassing CPU, disk, and collective overheads. 
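The control-plane ordering described above (publish metadata, poll for a newer version, then move bytes) can be sketched with a toy in-memory model. This is only an illustration under stated assumptions: `ToyMxServer`, `refit_once`, and the `tensor_names` metadata key are hypothetical names, not part of the ModelExpress or NIXL APIs, and a plain byte copy stands in for the RDMA transfer.

```python
from dataclasses import dataclass, field


@dataclass
class ToyMxServer:
    """In-memory stand-in for the MX Server metadata store (hypothetical)."""

    latest: dict = field(default_factory=dict)  # rank -> (step, metadata)

    def publish(self, rank: int, step: int, metadata: dict) -> None:
        # Keep only the newest step per rank, as version tracking would.
        current = self.latest.get(rank)
        if current is None or step > current[0]:
            self.latest[rank] = (step, metadata)

    def poll(self, rank: int, have_step: int):
        # Return (step, metadata) only if something newer than have_step exists.
        entry = self.latest.get(rank)
        if entry is not None and entry[0] > have_step:
            return entry
        return None


def refit_once(server, rank, have_step, trainer_vram, inference_vram):
    """One refit attempt for an inference rank; returns the step now loaded."""
    update = server.poll(rank, have_step)
    if update is None:
        return have_step  # nothing new, keep serving rollouts
    step, metadata = update
    for name in metadata["tensor_names"]:
        # Stand-in for the RDMA transfer: copy bytes between two "VRAM" dicts.
        inference_vram[name] = bytes(trainer_vram[name])
    return step
```

The point of the sketch is that the server only ever carries metadata and a step number, while the weight bytes move directly between the two workers.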
+ +### High-level data flow + +```text +Training Workers MX Server Inference Workers +(FSDP2 / Megatron) (gRPC + Redis/CRD) (vLLM / SGLang) + + WeightExtractor ──gRPC──► Metadata Coord ◄──gRPC── MxRefitReceiver + MxTrainingPublisher Version Tracking NIXL Agent + NIXL Agent + │ ▲ + └══════════════ RDMA WRITE (GPU→GPU) ════════════════┘ + bypasses CPU & disk +``` + +> *See: [diagram-architecture.svg](diagram-architecture.svg)* + +### Performance + +| Metric | Value | +|--------|-------| +| Llama-3.3-70B (140 GB) | **~5s** | +| DeepSeek-V3 MoE (681 GB) | **~15s** | +| CPU staging required | **0** | +| Transport fallback | **Auto**: IPC → RDMA → TCP | + +--- + +## Slide 4 — Architecture: Component Deep-Dive + +> *See: [diagram-component-stack.svg](diagram-component-stack.svg)* + +### Training Workers + +| Component | Status | Description | +|-----------|--------|-------------| +| RL Framework | Existing | NeMo RL / verl / PRIME-RL | +| Training Backend | Existing | FSDP2 / Megatron-LM | +| **WeightExtractor** | **NEW** | Gather params per bucket (FSDP2 / Megatron) | +| **MxTrainingPublisher** | **NEW** | Register tensors with NIXL, publish metadata (gRPC), version tag | +| **ResharderPlugin** | **NEW** | Gather-then-shard / Direct-match / Auto | +| NIXL Agent | Existing | UCX backend, per-GPU, RDMA WRITE | + +### ModelExpress Server (Rust) + +| Component | Status | Description | +|-----------|--------|-------------| +| gRPC P2P Service | Existing | PublishMetadata, ListSources, GetMetadata, UpdateStatus | +| Redis / K8s CRD Backend | Existing | Metadata persistence and HA | +| **Refit Version Tracking** | **NEW** | training_step in SourceIdentity; version-filtered ListSources | +| **Bucket Coordination** | **NEW** | RefitCoordination message for bucket-level progress | +| Heartbeat / Reaper | Existing | Stale source detection and GC | +| p2p.proto | **MODIFIED** | New enums, fields, messages for RL refit | + +### Inference Workers + +| Component | Status | Description 
| +|-----------|--------|-------------| +| Inference Engine | Existing | vLLM / SGLang | +| **MxRefitReceiver** | **NEW** | Poll for new weight versions, coordinate with MX Server | +| NIXL Agent | Existing | UCX backend, pre-registered receive buffers | +| Apply Weights In-Place | Existing | FP8 quantization on receiver side | +| **SyncPolicyController** | **NEW** | Sync / One-step-off / Fully async modes | +| Resume Rollout | Existing | Continue generation after refit | + +### Data paths + +- **Metadata (gRPC)**: Training → MX Server → Inference (dashed, control plane) +- **Data plane (RDMA WRITE)**: Training NIXL Agent → Inference NIXL Agent (per bucket) + +--- + +## Slide 5 — Integration Map: Three Frameworks, One Abstraction + +A framework-agnostic **WeightSyncBackend** decouples the transfer mechanism from each framework's orchestration and sync policy. + +> *See: [diagram-framework-comparison.svg](diagram-framework-comparison.svg)* + +### Comparison + +| Dimension | NeMo RL | verl | PRIME-RL | +|-----------|---------|------|----------| +| Training Backend | DTensor / Megatron | FSDP / FSDP2 / Megatron | FSDP2 (EP/CP) | +| Inference Backend | vLLM, SGLang | vLLM, SGLang, HF | vLLM | +| Current Sync | ZMQ IPC / HTTP / NCCL | NCCL / CheckpointEngine | Filesystem + HTTP | +| **MX Insertion Point** | `refit_policy_generation()` | `CheckpointEngine` ABC | Orchestrator `relay_weights()` | +| Primary Benefit | RDMA replaces ZMQ/NCCL | Multi-node sans filesystem | Eliminates disk I/O | + +### Per-framework summary + +**NeMo RL** — New branch alongside ZMQ IPC and NCCL in refit function. Bucket-streamed transfer maps to MX publish. Ray actor integration. + +**verl** — `ModelExpressCheckpointEngine` implements v0.7 `CheckpointEngine` ABC. Targets async server mode. Engine mode stays NCCL (already optimal co-located). + +**PRIME-RL** — Replaces filesystem relay in orchestrator. For cross-DC (Intellect-2), MX acts as fast intra-cluster delivery under SHARDCAST. 
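One way to picture the **WeightSyncBackend** idea from this slide is an abstract transport interface with the sync policy kept outside it. The sketch below is hypothetical: the method names `publish` and `fetch`, the `InMemoryBackend` class, and the exact lag rule behind each policy are assumptions made for illustration, with the policy names taken from the Sync / One-step-off / Fully async modes listed on Slide 4.

```python
from abc import ABC, abstractmethod


class WeightSyncBackend(ABC):
    """Hypothetical interface; method names are illustrative, not from MX."""

    @abstractmethod
    def publish(self, step: int, tensors: dict) -> None:
        """Make the weights for `step` available to inference workers."""

    @abstractmethod
    def fetch(self, min_step: int):
        """Return (step, tensors) if a version >= min_step exists, else None."""


class InMemoryBackend(WeightSyncBackend):
    """Toy transport that keeps the latest version in process memory."""

    def __init__(self):
        self._step = None
        self._tensors = None

    def publish(self, step, tensors):
        self._step, self._tensors = step, dict(tensors)

    def fetch(self, min_step):
        if self._step is None or self._step < min_step:
            return None
        return self._step, dict(self._tensors)


def sync_policy_allows(policy: str, trainer_step: int, loaded_step: int) -> bool:
    """Whether rollouts may proceed with the weights from `loaded_step`.

    Assumed semantics: "sync" requires the newest weights, "one_step_off"
    tolerates a lag of one training step, and "async" never blocks.
    """
    lag = trainer_step - loaded_step
    if policy == "sync":
        return lag <= 0
    if policy == "one_step_off":
        return lag <= 1
    if policy == "async":
        return True
    raise ValueError(f"unknown policy: {policy}")
```

Keeping `sync_policy_allows` separate from the backend mirrors the stated goal: a framework could swap the transport (filesystem, NCCL, MX) without touching how stale its rollouts are allowed to be.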
+ +--- + +## Slide 6 — Delivery: Phased Integration Plan + +### Phase 1 — Weeks 1-4: Foundation + +- WeightExtractor abstraction (FSDP2 + Megatron) +- MxTrainingPublisher with NIXL + gRPC +- Proto extensions (`RL_REFIT`, `training_step`) +- Single-node RDMA validation +- Benchmark vs. ZMQ IPC baseline + +### Phase 2 — Weeks 5-10: Framework Integrations + +- NeMo RL: MX branch in refit function +- verl: `ModelExpressCheckpointEngine` +- PRIME-RL: Orchestrator MX relay +- Fallback to existing mechanisms +- E2E GRPO/PPO correctness tests + +### Phase 3 — Weeks 11-13: Hardening + +- ResharderPlugin (gather-then-shard) +- Error handling + fallback paths +- MoE bucket completion tracking +- FP8 quantization validation +- Multi-node benchmarks (70B, MoE) + +### Target Performance + +| Model | Nodes | Current | MX Target | +|-------|-------|---------|-----------| +| Llama-3.1-8B | 2 | NCCL ~3s | **MX ~1s** | +| Llama-3.3-70B | 4 | ~10-20s | **MX ~5s** | +| DeepSeek-V3 MoE | 8+ | ~30s+ | **MX ~15s** | + +### Key risks + +- **Parallelism layout mismatch** — mitigated by ResharderPlugin and config alignment guidance +- **InfiniBand dependency** — clean fallback to NCCL/TCP/filesystem when unavailable + +--- + +## Diagrams + +All diagrams are available as standalone SVG files for embedding in external slide tools: + +| File | Description | +|------|-------------| +| [diagram-rl-loop-bottleneck.svg](diagram-rl-loop-bottleneck.svg) | RL training loop with refit bottleneck highlighted | +| [diagram-architecture.svg](diagram-architecture.svg) | Three-column architecture: Training → MX Server → Inference | +| [diagram-component-stack.svg](diagram-component-stack.svg) | Full component stack with NEW/MODIFIED/EXISTING tags | +| [diagram-transfer-flow.svg](diagram-transfer-flow.svg) | Sequence diagram: one refit step (publish → poll → RDMA WRITE → apply) | +| [diagram-framework-comparison.svg](diagram-framework-comparison.svg) | NeMo RL / verl / PRIME-RL comparison grid | diff --git 
a/modelexpress_client/python/benchmarks/README.md b/modelexpress_client/python/benchmarks/README.md new file mode 100644 index 00000000..b641737e --- /dev/null +++ b/modelexpress_client/python/benchmarks/README.md @@ -0,0 +1,147 @@ +# MX v2 benchmarks + +Transport-layer benchmarks for ModelExpress v2 + `MxWeightTransferEngine`. + +## What's measured + +- **Cold-start join latency** for receivers in an elastic deployment +- **Per-receive RDMA bandwidth** (Gbit/s, reported as `bandwidth_gbps`) and tensor count +- **Discovery (control-plane) vs RDMA (data-plane)** latency split +- **Compile-target filter** behavior (accept / reject / back-compat) +- **Trainer egress savings under tree fan-out** (pipeline replication) + +These exercise the same v2 fat-client code paths vLLM hits through +`MxWeightTransferEngine`, so the numbers are representative. + +## Quick CPU smoke (no MX server / NIXL / GPUs needed) + +```bash +python bench_elastic_scaling.py --mode=cpu --scenario=tree_fanout \ + --num-receivers=4 --steps=3 +``` + +Useful for orchestrator-logic CI and developing scenarios offline. + +## Live runs (MX server + GPU + NIXL required) + +Single host, one trainer + 3 receivers, two refit cycles: + +```bash +export MX_SERVER_URL=modelexpress-server.kavin.svc.cluster.local:8001 +python bench_elastic_scaling.py \ + --scenario=elastic_scale \ + --num-receivers=3 --steps=2 \ + --num-tensors=64 --tensor-bytes=$((8*1024*1024)) \ + --join-interval=2.0 --step-interval=3.0 \ + --output=elastic.json +``` + +Compile-target safety net + back-compat demo (one trainer, three +receivers with different filters): + +```bash +python bench_elastic_scaling.py \ + --scenario=compile_target \ + --trainer-compile-target=cutlass_fp8 \ + --num-tensors=16 --tensor-bytes=$((4*1024*1024)) \ + --output=compile_target.json +``` + +Expected: `recv-match` accepts, `recv-mismatch` is rejected at +discovery (no RDMA cycles spent), `recv-no-filter` accepts (back-compat). 
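The expected verdicts above follow from a simple membership rule. The sketch below is a minimal illustration of that rule, not the harness's actual code: the `accepts` helper is hypothetical, and treating `None` as "no filter" is an assumption based on the back-compat behavior described for `recv-no-filter`.

```python
def accepts(source_compile_target, receiver_filter):
    """Return True if a receiver with `receiver_filter` should pull from the source.

    `receiver_filter` is a list of acceptable compile targets, or None for no filter.
    """
    if receiver_filter is None:
        return True  # back-compat: a receiver without a filter accepts any source
    return source_compile_target in receiver_filter


# The three receivers from the compile_target scenario, against a
# trainer that tags its tensors with cutlass_fp8.
receivers = {
    "recv-match": ["cutlass_fp8"],
    "recv-mismatch": ["deep_gemm_fp8"],
    "recv-no-filter": None,
}
verdicts = {
    name: "accepted" if accepts("cutlass_fp8", filt) else "rejected"
    for name, filt in receivers.items()
}
```

Because the check needs only metadata, a mismatched receiver can be turned away during discovery, before any transfer is attempted.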
+ +Tree fan-out (newcomers pull from earlier receivers, not the trainer): + +```bash +python bench_elastic_scaling.py \ + --scenario=tree_fanout \ + --num-receivers=4 --steps=3 \ + --join-interval=2.0 --step-interval=4.0 \ + --num-tensors=64 --tensor-bytes=$((8*1024*1024)) \ + --output=tree_fanout.json +``` + +Expected: `fanout_factor > 1.0` — total bytes delivered exceeds trainer +egress because receivers 2..N pulled from already-loaded peers. + +## Cluster mode (Kubernetes — kavin namespace) + +A turnkey Job manifest at `k8s/bench-elastic.yaml` runs all three +scenarios in sequence and stashes the JSON outputs in `/results/` +inside the pod. A driver script at `run_cluster_bench.sh` wraps the +apply + wait + collect cycle: + +```bash +./run_cluster_bench.sh # runs all 3 scenarios, collects JSON +./run_cluster_bench.sh --watch # also tails the pod logs live +``` + +After completion, results land in `./results-/` with one +JSON file per scenario plus a printed summary. The manifest requests +5 GPUs (1 trainer + 4 receivers); adjust the `nvidia.com/gpu` request +in the manifest if your namespace has different quota. + +The image is pinned to `nvcr.io/nvidian/prime-rl:v0.5.2` in the +manifest by default; update the tag if you want a different +modelexpress build. The harness lives at +`modelexpress/benchmarks/bench_elastic_scaling.py` inside any image +that has this branch's modelexpress install. + +## Output schema + +`--output results.json` produces a machine-readable document: + +```json +{ + "scenario": "elastic_scale", + "config": { ... CLI args ... 
}, + "started_at": 1748567890.12, + "finished_at": 1748567945.41, + "wall_seconds": 55.29, + "trainer": { + "worker_id": "bench-trainer-r0", + "mx_source_id": "...", + "published_versions": [1, 2, 3], + "compile_target": "cutlass_fp8", + "total_published_bytes": 1342177280 + }, + "receivers": [ + { + "receiver_id": "recv-0", + "worker_rank": 0, + "join_latency_seconds": 0.41, + "compile_target_filter": null, + "cycles": [ + { + "version": 1, + "bytes_received": 134217728, + "rdma_seconds": 0.082, + "bandwidth_gbps": 13.1, + "discovery_seconds": 0.014, + "source_worker_rank": 0 + } + ] + } + ], + "derived": { + "trainer_egress_bytes": 402653184, + "total_delivered_bytes": 1207959552, + "scenario_specific": { + "fanout_factor": 3.0, + "trainer_egress_mb": 402.7, + "total_delivered_mb": 1208.0 + } + } +} +``` + +## Caveats + +- The harness expects a working MX server reachable at + `--mx-server-url`. Boot one in your namespace before running. +- "Live" mode requires NIXL + CUDA; `--mode=cpu` is the fallback. +- The trainer subprocess holds the source alive for a generous tail + past its last publish so late receivers can still discover it. If + you need long-running publishes, run the trainer separately and + point receivers at it via `--num-receivers=N --steps=0` (skips + trainer launch). (Roadmap.) diff --git a/modelexpress_client/python/benchmarks/__init__.py b/modelexpress_client/python/benchmarks/__init__.py new file mode 100644 index 00000000..31ecfdb9 --- /dev/null +++ b/modelexpress_client/python/benchmarks/__init__.py @@ -0,0 +1,7 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +"""Benchmark harnesses for ModelExpress v2 + MxWeightTransferEngine. + +See bench_elastic_scaling.py for the main entry point and README.md +for usage examples + reproducibility notes. 
+""" diff --git a/modelexpress_client/python/benchmarks/bench_elastic_scaling.py b/modelexpress_client/python/benchmarks/bench_elastic_scaling.py new file mode 100644 index 00000000..03bf7d19 --- /dev/null +++ b/modelexpress_client/python/benchmarks/bench_elastic_scaling.py @@ -0,0 +1,827 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Elastic-scale-up + compile-target + tree-fan-out benchmark for MX v2. + +This is a *transport-layer* benchmark — it spins up a synthetic trainer +that publishes tensors via :class:`MxV2TrainingPublisher`, then spawns +N receivers (each :class:`MxV2RefitReceiver` driven via +:class:`MxWeightTransferEngine`) and records: + + - cold-start join latency (time from receiver start to first successful + receive) + - per-cycle RDMA bandwidth (GB/s) and tensor count + - discovery (control-plane) latency vs RDMA (data-plane) latency + - compile-target filter behavior (accept / reject / back-compat-no-filter) + - trainer egress savings under tree fan-out (pipeline replication) + +The benchmark does NOT need vLLM, NCCL, or even ``transformers`` — +just the MX v2 fat clients and a live MX server. It does need a CUDA +device and NIXL for real RDMA numbers; with ``--mode=cpu`` it runs a +shape-only smoke test (no real transfer, useful for CI). + +Scenarios +--------- + +``elastic_scale`` + Trainer publishes a fixed model for ``--steps`` versions. Receivers + join staggered every ``--join-interval`` seconds. Per receiver we + record join latency and per-version bandwidth. + +``compile_target`` + Publisher tags bytes with ``compile_target="cutlass_fp8"``. Three + receivers run simultaneously: + + 1. ``filter=cutlass_fp8`` (matched) — accepts + 2. ``filter=deep_gemm_fp8`` (mismatched) — refuses BEFORE RDMA + 3. 
``filter=None`` (back-compat) — accepts (no filter) + + Output proves the safety net + the back-compat property in one shot. + +``tree_fanout`` + Identical to ``elastic_scale`` but with ``publish_self_as_replica=True`` + on every receiver. Receivers 2..N can discover and pull from + receivers 1..N-1 instead of the trainer. We measure trainer egress + bytes vs total bytes received as the "fan-out factor". + +Usage +----- + +Single-host smoke (no GPUs, no MX server — exercises plumbing only):: + + python bench_elastic_scaling.py --mode=cpu --scenario=elastic_scale \\ + --num-receivers=3 --tensor-bytes=1048576 --num-tensors=4 + +Full cluster run (against an MX server in the kavin namespace):: + + python bench_elastic_scaling.py \\ + --mx-server-url=modelexpress-server.kavin.svc.cluster.local:8001 \\ + --scenario=elastic_scale --num-receivers=4 --steps=3 \\ + --join-interval=2.0 --num-tensors=64 --tensor-bytes=8388608 \\ + --output=results.json + +Outputs +------- + +A JSON document with the metrics blob (machine-readable) and a printed +human summary table. Pipe ``--output=results.json`` to capture and +compare across runs. 
+""" + +from __future__ import annotations + +import argparse +import json +import logging +import multiprocessing as mp +import os +import sys +import time +from dataclasses import asdict, dataclass, field +from typing import Any + +logger = logging.getLogger("bench_elastic_scaling") + + +# ---------------------------------------------------------------------------- +# Result schema +# ---------------------------------------------------------------------------- + + +@dataclass +class ReceiverCycleResult: + """One receive cycle worth of metrics.""" + + version: int + bytes_received: int = 0 + tensors_received: int = 0 + rdma_seconds: float = 0.0 + bandwidth_gbps: float = 0.0 + discovery_seconds: float = 0.0 + source_worker_rank: int | None = None + error: str | None = None + + +@dataclass +class ReceiverResult: + """All cycles for one receiver.""" + + receiver_id: str + worker_rank: int + started_at: float + first_receive_at: float | None = None + join_latency_seconds: float | None = None # = first_receive_at - started_at + compile_target_filter: list[str] | None = None + cycles: list[ReceiverCycleResult] = field(default_factory=list) + + def total_bytes(self) -> int: + return sum(c.bytes_received for c in self.cycles) + + def total_rdma_seconds(self) -> float: + return sum(c.rdma_seconds for c in self.cycles) + + def avg_bandwidth_gbps(self) -> float: + t = self.total_rdma_seconds() + return (self.total_bytes() * 8) / (t * 1e9) if t > 0 else 0.0 + + +@dataclass +class TrainerResult: + """Publisher-side stats.""" + + worker_id: str | None + mx_source_id: str | None + started_at: float + published_versions: list[int] = field(default_factory=list) + compile_target: str | None = None + total_published_bytes: int = 0 + + +@dataclass +class BenchResult: + scenario: str + config: dict[str, Any] + trainer: TrainerResult | None + receivers: list[ReceiverResult] + started_at: float + finished_at: float + + def trainer_egress_bytes(self) -> int: + """Bytes the trainer 
actually had to serve out. + + For matched-TP + non-fan-out: == sum(receiver.total_bytes()). + For tree-fan-out: < that, because newcomers pulled from peers. + We approximate this as "bytes received by receivers whose + source_worker_rank == the trainer's rank"; receivers that + pulled from replicas don't count. + + Since the worker_rank of the trainer is always 0 in this + harness, and same-rank-only is set, the receiver's + source_worker_rank == 0 means "pulled from trainer". Other + values are "pulled from a replica". + """ + egress = 0 + for r in self.receivers: + for c in r.cycles: + if c.source_worker_rank == 0: + egress += c.bytes_received + return egress + + def to_summary_table(self) -> str: + lines = [] + lines.append(f"== Scenario: {self.scenario} ==") + lines.append(f"Wall time: {self.finished_at - self.started_at:.2f}s") + if self.trainer is not None: + lines.append( + f"Trainer: {self.trainer.worker_id} versions={self.trainer.published_versions} " + f"bytes={self.trainer.total_published_bytes / 1e6:.1f} MB " + f"compile_target={self.trainer.compile_target}" + ) + lines.append("") + lines.append( + f"{'receiver':<20} {'filter':<18} {'join_s':>8} {'cycles':>6} " + f"{'bytes_MB':>10} {'avg_Gbps':>10} {'errors':>7}" + ) + for r in self.receivers: + errors = sum(1 for c in r.cycles if c.error) + filt = ( + ",".join(r.compile_target_filter) + if r.compile_target_filter + else "(none)" + ) + join_str = ( + f"{r.join_latency_seconds:.2f}" + if r.join_latency_seconds is not None + else "n/a" + ) + lines.append( + f"{r.receiver_id:<20} {filt:<18} {join_str:>8} " + f"{len(r.cycles):>6} " + f"{r.total_bytes() / 1e6:>10.1f} " + f"{r.avg_bandwidth_gbps():>10.2f} " + f"{errors:>7}" + ) + if self.scenario.removesuffix("_cpu_smoke") == "tree_fanout": + total = sum(r.total_bytes() for r in self.receivers) + egress = self.trainer_egress_bytes() + ratio = total / egress if egress > 0 else float("inf") + lines.append("") + lines.append( + f"Tree fan-out: 
trainer_egress={egress / 1e6:.1f} MB, " + f"total_delivered={total / 1e6:.1f} MB, fanout_factor={ratio:.2f}x" + ) + return "\n".join(lines) + + def to_json(self) -> dict[str, Any]: + return { + "scenario": self.scenario, + "config": self.config, + "started_at": self.started_at, + "finished_at": self.finished_at, + "wall_seconds": self.finished_at - self.started_at, + "trainer": asdict(self.trainer) if self.trainer else None, + "receivers": [asdict(r) for r in self.receivers], + "derived": { + "trainer_egress_bytes": self.trainer_egress_bytes(), + "total_delivered_bytes": sum(r.total_bytes() for r in self.receivers), + "scenario_specific": _scenario_derived(self), + }, + } + + +def _scenario_derived(b: BenchResult) -> dict[str, Any]: + """Per-scenario derived numbers that don't fit the generic schema. + + Accepts both ``"elastic_scale"`` and ``"elastic_scale_cpu_smoke"`` + style names so the CPU smoke path produces the same derived + metrics as live runs. + """ + name = b.scenario.removesuffix("_cpu_smoke") + if name == "elastic_scale": + latencies = [ + r.join_latency_seconds + for r in b.receivers + if r.join_latency_seconds is not None + ] + return { + "join_latency_p50": _p(latencies, 0.5), + "join_latency_p99": _p(latencies, 0.99), + } + if name == "compile_target": + verdicts: dict[str, str] = {} + for r in b.receivers: + ok = any(not c.error for c in r.cycles) + verdicts[r.receiver_id] = "accepted" if ok else "rejected" + return {"verdicts": verdicts} + if name == "tree_fanout": + total = sum(r.total_bytes() for r in b.receivers) + egress = b.trainer_egress_bytes() + return { + "fanout_factor": (total / egress) if egress > 0 else None, + "trainer_egress_mb": egress / 1e6, + "total_delivered_mb": total / 1e6, + } + return {} + + +def _p(values: list[float], q: float) -> float | None: + if not values: + return None + s = sorted(values) + idx = min(len(s) - 1, int(q * (len(s) - 1) + 0.5)) + return s[idx] + + +# 
---------------------------------------------------------------------------- +# Trainer + receiver entry points (run as subprocesses) +# ---------------------------------------------------------------------------- + + +def _run_trainer( + *, + role: str, + mx_server_url: str, + model_name: str, + num_tensors: int, + tensor_bytes: int, + steps: int, + step_interval_s: float, + compile_target: str, + device_id: int, + result_path: str, +) -> None: + """Publisher subprocess entry point.""" + import torch + from modelexpress.nemo_rl_v2 import MxV2TrainingPublisher, TrainerWorldLayout + from modelexpress.shape_descriptors import COMPILE_TARGET_HF_RAW + + layout = TrainerWorldLayout(tp_world_size=1, pp_world_size=1, ep_world_size=1) + pub = MxV2TrainingPublisher( + agent_name="bench-trainer-r0", + device_id=device_id, + mx_server_url=mx_server_url, + worker_rank=0, + world_layout=layout, + ) + pub.initialize(model_name=model_name, dtype="bfloat16") + + # Build synthetic tensors once and reuse buffers across steps. + dtype = torch.bfloat16 + elem_size = torch.tensor([], dtype=dtype).element_size() + numel_per_tensor = max(1, tensor_bytes // elem_size) + device = torch.device(f"cuda:{device_id}") if torch.cuda.is_available() else torch.device("cpu") + tensors = { + f"layer{i}.weight": torch.randn(numel_per_tensor, dtype=dtype, device=device) + for i in range(num_tensors) + } + total_bytes = num_tensors * numel_per_tensor * elem_size + + result = TrainerResult( + worker_id=pub.worker_id, + mx_source_id=pub.mx_source_id, + started_at=time.time(), + compile_target=compile_target, + total_published_bytes=total_bytes * steps, + ) + for version in range(1, steps + 1): + for name, t in tensors.items(): + pub.add_tensor( + name=name, + tensor=t, + compile_target=compile_target, + compile_metadata={"benchmark": "elastic", "step": version}, + ) + pub.publish(version=version) + # Bump status to READY so receivers' list_sources() finds it. 
+ # On the first publish this also starts the heartbeat; subsequent + # calls are idempotent (the publisher only re-registers if needed). + pub.mark_ready() + result.published_versions.append(version) + logger.info("trainer: published v=%d (%d tensors)", version, num_tensors) + if version < steps: + time.sleep(step_interval_s) + + # Hold the trainer alive long enough for late receivers to find us. + # The orchestrator signals shutdown by deleting the heartbeat lock + # file; for simplicity here we just sleep for a generous tail. + time.sleep(max(5.0, step_interval_s * steps)) + + pub.shutdown() + with open(result_path, "w") as f: + json.dump(asdict(result), f) + + +def _run_receiver( + *, + receiver_id: str, + worker_rank: int, + mx_server_url: str, + model_name: str, + device_id: int, + listen_port: int | None, + compile_target_filter: list[str] | None, + target_versions: list[int], + poll_interval_s: float, + cycle_timeout_s: float, + deadline_s: float, + publish_self_as_replica: bool, + result_path: str, +) -> None: + """Receiver subprocess entry point. + + Drives the v2 receiver via :class:`MxWeightTransferEngine` so we + exercise the actual adapter path that vLLM will use. 
+ """ + os.environ.setdefault("MX_WEIGHT_TRANSFER_AUTOREGISTER", "0") + from modelexpress.vllm_weight_transfer import ( + MxInitInfo, + MxUpdateInfo, + MxWeightTransferEngine, + ) + + engine = MxWeightTransferEngine( + init_info=MxInitInfo( + mx_server_url=mx_server_url, + model_name=model_name, + worker_rank=worker_rank, + agent_name=f"bench-{receiver_id}", + device_id=device_id, + listen_port=listen_port, + publish_self_as_replica=publish_self_as_replica, + ) + ) + result = ReceiverResult( + receiver_id=receiver_id, + worker_rank=worker_rank, + started_at=time.time(), + compile_target_filter=compile_target_filter, + ) + captured: list[tuple[str, Any]] = [] + deadline = time.monotonic() + deadline_s + + for v in target_versions: + cycle = ReceiverCycleResult(version=v) + # Poll until a candidate source is observable. + while time.monotonic() < deadline: + try: + engine.receive_weights( + MxUpdateInfo( + version=v, + compile_target_filter=set(compile_target_filter) + if compile_target_filter + else None, + timeout_seconds=cycle_timeout_s, + ), + load_weights=captured.extend, + ) + stats = engine.last_transfer_stats + if stats is not None: + cycle.bytes_received = stats.bytes_received + cycle.tensors_received = stats.tensors_received + cycle.rdma_seconds = stats.elapsed_seconds + cycle.bandwidth_gbps = stats.bandwidth_gbps + cycle.source_worker_rank = stats.source_worker_rank + cycle.discovery_seconds = engine.last_discovery_seconds + if result.first_receive_at is None: + result.first_receive_at = time.time() + result.join_latency_seconds = ( + result.first_receive_at - result.started_at + ) + logger.info( + "%s: v=%d bytes=%.1fMB rdma=%.2fs %.1fGbps from_rank=%s", + receiver_id, + v, + cycle.bytes_received / 1e6, + cycle.rdma_seconds, + cycle.bandwidth_gbps, + cycle.source_worker_rank, + ) + break + except RuntimeError as e: + msg = str(e) + # Phase 3b safety net: filter rejection is a "decided" outcome, + # not a transient error. Record it and move on. 
+ if "no source matches filters" in msg or "no covering source set" in msg: + cycle.error = msg + logger.info("%s: v=%d filter rejected: %s", receiver_id, v, msg) + break + # Otherwise the source isn't published yet — poll again. + time.sleep(poll_interval_s) + else: + cycle.error = "deadline exceeded" + result.cycles.append(cycle) + + with open(result_path, "w") as f: + json.dump(asdict(result), f) + + +# ---------------------------------------------------------------------------- +# Orchestrator +# ---------------------------------------------------------------------------- + + +def _spawn(target, kwargs: dict[str, Any]) -> mp.Process: + p = mp.Process(target=target, kwargs=kwargs, daemon=True) + p.start() + return p + + +def _load_result(path: str, cls): + with open(path) as f: + d = json.load(f) + if cls is TrainerResult: + return TrainerResult(**d) + # ReceiverResult — rehydrate nested cycles + cycles = [ReceiverCycleResult(**c) for c in d.pop("cycles", [])] + return ReceiverResult(cycles=cycles, **d) + + +def run_elastic_scale(args: argparse.Namespace) -> BenchResult: + """N receivers join staggered. Trainer publishes ``--steps`` versions. + + Each receiver is launched with a delay relative to the previous. + All receivers try to consume every published version (so cold + joiners back-fill). 
+ """ + tmpdir = args.tmpdir + os.makedirs(tmpdir, exist_ok=True) + started = time.time() + + trainer_path = os.path.join(tmpdir, "trainer.json") + trainer_proc = _spawn( + _run_trainer, + dict( + role="trainer", + mx_server_url=args.mx_server_url, + model_name=args.model_name, + num_tensors=args.num_tensors, + tensor_bytes=args.tensor_bytes, + steps=args.steps, + step_interval_s=args.step_interval, + compile_target=args.trainer_compile_target, + device_id=0, + result_path=trainer_path, + ), + ) + time.sleep(args.trainer_warmup) + + receiver_procs = [] + target_versions = list(range(1, args.steps + 1)) + for i in range(args.num_receivers): + rid = f"recv-{i}" + rpath = os.path.join(tmpdir, f"{rid}.json") + receiver_procs.append( + ( + rpath, + _spawn( + _run_receiver, + dict( + receiver_id=rid, + worker_rank=0, # same-rank pull + mx_server_url=args.mx_server_url, + model_name=args.model_name, + device_id=0, + listen_port=None, + compile_target_filter=None, + target_versions=target_versions, + poll_interval_s=args.poll_interval, + cycle_timeout_s=args.cycle_timeout, + deadline_s=args.deadline, + publish_self_as_replica=False, + result_path=rpath, + ), + ), + ) + ) + time.sleep(args.join_interval) + + for _, p in receiver_procs: + p.join(timeout=args.deadline + 30) + trainer_proc.join(timeout=args.deadline + 60) + finished = time.time() + + trainer_result = _load_result(trainer_path, TrainerResult) + receivers = [_load_result(rp, ReceiverResult) for rp, _ in receiver_procs] + return BenchResult( + scenario="elastic_scale", + config=vars(args), + trainer=trainer_result, + receivers=receivers, + started_at=started, + finished_at=finished, + ) + + +def run_compile_target(args: argparse.Namespace) -> BenchResult: + """Trainer publishes with a fixed compile_target; three receivers + with different filters demonstrate accept / reject / back-compat.""" + tmpdir = args.tmpdir + os.makedirs(tmpdir, exist_ok=True) + started = time.time() + + trainer_path = os.path.join(tmpdir, 
"trainer.json") + trainer_proc = _spawn( + _run_trainer, + dict( + role="trainer", + mx_server_url=args.mx_server_url, + model_name=args.model_name, + num_tensors=args.num_tensors, + tensor_bytes=args.tensor_bytes, + steps=args.steps, + step_interval_s=args.step_interval, + compile_target=args.trainer_compile_target, # e.g. "cutlass_fp8" + device_id=0, + result_path=trainer_path, + ), + ) + time.sleep(args.trainer_warmup) + + # Three receivers running concurrently + scenarios = [ + ("recv-match", [args.trainer_compile_target]), + ("recv-mismatch", ["deep_gemm_fp8"]), + ("recv-no-filter", None), + ] + procs = [] + for rid, filt in scenarios: + rpath = os.path.join(tmpdir, f"{rid}.json") + procs.append( + ( + rpath, + _spawn( + _run_receiver, + dict( + receiver_id=rid, + worker_rank=0, + mx_server_url=args.mx_server_url, + model_name=args.model_name, + device_id=0, + listen_port=None, + compile_target_filter=filt, + target_versions=[1], # one cycle is enough to demo + poll_interval_s=args.poll_interval, + cycle_timeout_s=args.cycle_timeout, + deadline_s=args.deadline, + publish_self_as_replica=False, + result_path=rpath, + ), + ), + ) + ) + + for _, p in procs: + p.join(timeout=args.deadline + 30) + trainer_proc.join(timeout=args.deadline + 60) + finished = time.time() + + return BenchResult( + scenario="compile_target", + config=vars(args), + trainer=_load_result(trainer_path, TrainerResult), + receivers=[_load_result(rp, ReceiverResult) for rp, _ in procs], + started_at=started, + finished_at=finished, + ) + + +def run_tree_fanout(args: argparse.Namespace) -> BenchResult: + """Like elastic_scale but receivers also publish_self_as_replica. + + Newcomers can pull from earlier receivers, so we expect the + trainer's egress bytes to be << total delivered bytes. 
+ """ + tmpdir = args.tmpdir + os.makedirs(tmpdir, exist_ok=True) + started = time.time() + + trainer_path = os.path.join(tmpdir, "trainer.json") + trainer_proc = _spawn( + _run_trainer, + dict( + role="trainer", + mx_server_url=args.mx_server_url, + model_name=args.model_name, + num_tensors=args.num_tensors, + tensor_bytes=args.tensor_bytes, + steps=args.steps, + step_interval_s=args.step_interval, + compile_target=args.trainer_compile_target, + device_id=0, + result_path=trainer_path, + ), + ) + time.sleep(args.trainer_warmup) + + procs = [] + target_versions = list(range(1, args.steps + 1)) + for i in range(args.num_receivers): + rid = f"recv-{i}" + rpath = os.path.join(tmpdir, f"{rid}.json") + procs.append( + ( + rpath, + _spawn( + _run_receiver, + dict( + receiver_id=rid, + worker_rank=0, + mx_server_url=args.mx_server_url, + model_name=args.model_name, + device_id=0, + listen_port=None, + compile_target_filter=None, + target_versions=target_versions, + poll_interval_s=args.poll_interval, + cycle_timeout_s=args.cycle_timeout, + deadline_s=args.deadline, + publish_self_as_replica=True, # the only diff + result_path=rpath, + ), + ), + ) + ) + time.sleep(args.join_interval) + + for _, p in procs: + p.join(timeout=args.deadline + 30) + trainer_proc.join(timeout=args.deadline + 60) + finished = time.time() + + return BenchResult( + scenario="tree_fanout", + config=vars(args), + trainer=_load_result(trainer_path, TrainerResult), + receivers=[_load_result(rp, ReceiverResult) for rp, _ in procs], + started_at=started, + finished_at=finished, + ) + + +# ---------------------------------------------------------------------------- +# CPU-only smoke mode — runs the orchestrator logic against stubs, exercises +# the harness without needing a server or RDMA. +# ---------------------------------------------------------------------------- + + +def run_cpu_smoke(args: argparse.Namespace) -> BenchResult: + """Drive the harness end-to-end with stubbed trainer/receivers. 
+ + The trainer and receivers run in-process and just simulate the + metrics they would produce. This lets us validate the orchestrator + + result aggregation + summary table without a live MX server. + """ + started = time.time() + trainer = TrainerResult( + worker_id="bench-trainer-r0", + mx_source_id="abcd1234efgh5678", + started_at=started, + compile_target=args.trainer_compile_target, + published_versions=list(range(1, args.steps + 1)), + total_published_bytes=args.num_tensors * args.tensor_bytes * args.steps, + ) + receivers = [] + for i in range(args.num_receivers): + join_delay = i * args.join_interval + r = ReceiverResult( + receiver_id=f"recv-{i}", + worker_rank=0, + started_at=started + join_delay, + first_receive_at=started + join_delay + 0.05, + join_latency_seconds=0.05, + ) + for v in range(1, args.steps + 1): + r.cycles.append( + ReceiverCycleResult( + version=v, + bytes_received=args.num_tensors * args.tensor_bytes, + tensors_received=args.num_tensors, + rdma_seconds=0.1, + bandwidth_gbps=(args.num_tensors * args.tensor_bytes * 8) / (0.1 * 1e9), + discovery_seconds=0.01, + source_worker_rank=0, + ) + ) + receivers.append(r) + + return BenchResult( + scenario=f"{args.scenario}_cpu_smoke", + config=vars(args), + trainer=trainer, + receivers=receivers, + started_at=started, + finished_at=time.time(), + ) + + +# ---------------------------------------------------------------------------- +# CLI +# ---------------------------------------------------------------------------- + + +def _parse_args(argv: list[str] | None = None) -> argparse.Namespace: + p = argparse.ArgumentParser( + description="Elastic-scale + compile-target + tree-fan-out benchmark for MX v2", + ) + p.add_argument( + "--scenario", + choices=["elastic_scale", "compile_target", "tree_fanout"], + default="elastic_scale", + ) + p.add_argument( + "--mode", + choices=["live", "cpu"], + default="live", + help="'live' = real trainer + receivers via subprocesses (needs MX server + NIXL); " + 
"'cpu' = stubbed orchestrator-only smoke", + ) + p.add_argument("--mx-server-url", default=os.environ.get("MX_SERVER_URL", "localhost:8001")) + p.add_argument("--model-name", default="bench/synthetic-1.5B") + p.add_argument("--num-receivers", type=int, default=3) + p.add_argument("--num-tensors", type=int, default=8) + p.add_argument("--tensor-bytes", type=int, default=8 * 1024 * 1024, + help="Bytes per tensor on the publisher side (default 8 MiB).") + p.add_argument("--steps", type=int, default=2) + p.add_argument("--step-interval", type=float, default=2.0) + p.add_argument("--join-interval", type=float, default=2.0, + help="Seconds between successive receiver starts.") + p.add_argument("--trainer-warmup", type=float, default=2.0, + help="Seconds to wait after starting the trainer before launching receivers.") + p.add_argument("--poll-interval", type=float, default=0.5) + p.add_argument("--cycle-timeout", type=float, default=60.0) + p.add_argument("--deadline", type=float, default=180.0) + p.add_argument("--trainer-compile-target", default="cutlass_fp8") + p.add_argument("--tmpdir", default="/tmp/mx_bench") + p.add_argument("--output", default=None, help="If set, also write JSON results to this path.") + p.add_argument("-v", "--verbose", action="store_true") + return p.parse_args(argv) + + +def main(argv: list[str] | None = None) -> int: + args = _parse_args(argv) + logging.basicConfig( + level=logging.DEBUG if args.verbose else logging.INFO, + format="%(asctime)s %(levelname)s %(name)s: %(message)s", + ) + + if args.mode == "cpu": + result = run_cpu_smoke(args) + else: + dispatch = { + "elastic_scale": run_elastic_scale, + "compile_target": run_compile_target, + "tree_fanout": run_tree_fanout, + } + result = dispatch[args.scenario](args) + + print(result.to_summary_table()) + if args.output: + with open(args.output, "w") as f: + json.dump(result.to_json(), f, indent=2, default=str) + print(f"\nFull JSON written to {args.output}") + return 0 + + +if __name__ == 
"__main__": + sys.exit(main()) diff --git a/modelexpress_client/python/benchmarks/k8s/bench-elastic.yaml b/modelexpress_client/python/benchmarks/k8s/bench-elastic.yaml new file mode 100644 index 00000000..482ea85d --- /dev/null +++ b/modelexpress_client/python/benchmarks/k8s/bench-elastic.yaml @@ -0,0 +1,145 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Kubernetes Job for the MX v2 transport-layer benchmark. +# +# Runs the bench_elastic_scaling.py harness inside the `kavin` +# namespace against the modelexpress-server deployment that already +# lives in that namespace. One pod runs all three scenarios in +# sequence — the harness handles its own trainer + receiver subprocess +# orchestration internally. +# +# Apply: +# kubectl apply -f bench-elastic.yaml +# +# Watch: +# kubectl -n kavin logs -f job/mx-bench-elastic +# +# Collect results once the Job completes: +# kubectl -n kavin cp $(kubectl -n kavin get pod -l job-name=mx-bench-elastic -o name | head -1 | cut -d/ -f2):/results /tmp/mx-bench-results +# +# Reuse the prime-rl image you're running in the namespace — that's +# the cleanest way to get NIXL + modelexpress + CUDA all aligned. +# Override IMAGE below if you want to pin a specific build. + +apiVersion: batch/v1 +kind: Job +metadata: + name: mx-bench-elastic + namespace: kavin + annotations: + # Auto-clean the namespace assignment after 4 hours so we don't + # hog GPUs after the run. + nscleanup/ttl: "4h" +spec: + ttlSecondsAfterFinished: 86400 # keep logs for 1 day + backoffLimit: 0 + template: + metadata: + labels: + app: mx-bench-elastic + spec: + restartPolicy: Never + # Pin to the same nodeSelector you use for inference workers. + # Replace with your environment's selector if different. 
+ nodeSelector: + cloud.google.com/gke-accelerator: nvidia-gb200 + tolerations: + - key: nvidia.com/gpu + operator: Exists + effect: NoSchedule + containers: + - name: bench + # Pin to the prime-rl image that has modelexpress + NIXL + # installed. Update the tag as needed. + image: nvcr.io/nvidian/prime-rl:v0.5.2 + imagePullPolicy: IfNotPresent + command: ["/bin/bash", "-lc"] + args: + - | + set -euo pipefail + mkdir -p /results + cd /app + + export MX_SERVER_URL=modelexpress-server.kavin.svc.cluster.local:8001 + export PYTHONUNBUFFERED=1 + # Disable vLLM auto-registration since this Job doesn't + # spin up an LLM — it just exercises the transport. + export MX_WEIGHT_TRANSFER_AUTOREGISTER=0 + + # Locate the harness. After the v2 branch is installed, + # it lives under modelexpress/benchmarks/. + HARNESS=$(python -c 'import modelexpress, os; print(os.path.join(os.path.dirname(modelexpress.__file__), "benchmarks", "bench_elastic_scaling.py"))') + echo "Using harness at: $HARNESS" + + echo + echo "============================================" + echo "=== Scenario 1 / 3: elastic_scale" + echo "============================================" + python "$HARNESS" \ + --mx-server-url=$MX_SERVER_URL \ + --scenario=elastic_scale \ + --num-receivers=4 --steps=3 \ + --join-interval=3.0 --step-interval=4.0 \ + --num-tensors=64 --tensor-bytes=$((8*1024*1024)) \ + --deadline=240 \ + --output=/results/elastic_scale.json + + echo + echo "============================================" + echo "=== Scenario 2 / 3: compile_target" + echo "============================================" + python "$HARNESS" \ + --mx-server-url=$MX_SERVER_URL \ + --scenario=compile_target \ + --trainer-compile-target=cutlass_fp8 \ + --num-tensors=16 --tensor-bytes=$((4*1024*1024)) \ + --deadline=120 \ + --output=/results/compile_target.json + + echo + echo "============================================" + echo "=== Scenario 3 / 3: tree_fanout" + echo "============================================" + python 
"$HARNESS" \ + --mx-server-url=$MX_SERVER_URL \ + --scenario=tree_fanout \ + --num-receivers=4 --steps=3 \ + --join-interval=3.0 --step-interval=4.0 \ + --num-tensors=64 --tensor-bytes=$((8*1024*1024)) \ + --deadline=240 \ + --output=/results/tree_fanout.json + + echo + echo "All scenarios complete. Results under /results:" + ls -la /results + echo + echo "Summary (compile_target verdicts):" + python -c "import json; d=json.load(open('/results/compile_target.json')); print(json.dumps(d['derived']['scenario_specific'], indent=2))" + echo + echo "Summary (tree_fanout factor):" + python -c "import json; d=json.load(open('/results/tree_fanout.json')); print(json.dumps(d['derived']['scenario_specific'], indent=2))" + + resources: + requests: + # The harness uses one GPU per subprocess. With trainer + # + 4 receivers in scenarios 1 and 3, request 5. + nvidia.com/gpu: "5" + memory: "32Gi" + cpu: "8" + limits: + nvidia.com/gpu: "5" + memory: "64Gi" + cpu: "16" + volumeMounts: + - mountPath: /results + name: results + - mountPath: /dev/shm + name: dshm + volumes: + - name: results + emptyDir: {} + - name: dshm + emptyDir: + medium: Memory + sizeLimit: 8Gi diff --git a/modelexpress_client/python/benchmarks/run_cluster_bench.sh b/modelexpress_client/python/benchmarks/run_cluster_bench.sh new file mode 100755 index 00000000..837f830d --- /dev/null +++ b/modelexpress_client/python/benchmarks/run_cluster_bench.sh @@ -0,0 +1,93 @@ +#!/usr/bin/env bash +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Drive the MX v2 benchmark Job on the kavin cluster + collect results +# into ./results-/. 
+# +# Prerequisites: +# - kubectl pointed at the right cluster + context +# - `kavin` namespace has modelexpress-server running +# - prime-rl image (or any image with modelexpress installed) is +# reachable from the namespace's nodeSelector +# +# Usage: +# ./run_cluster_bench.sh # runs all 3 scenarios, collects JSON +# ./run_cluster_bench.sh --watch # also tail logs while it runs + +set -euo pipefail + +HERE="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")" && pwd)" +MANIFEST="$HERE/k8s/bench-elastic.yaml" +NS="kavin" +JOB="mx-bench-elastic" + +WATCH="" +if [[ "${1:-}" == "--watch" ]]; then + WATCH=1 +fi + +echo "[1/4] Cleaning up any prior Job..." +kubectl -n "$NS" delete job "$JOB" --ignore-not-found=true + +echo "[2/4] Applying $MANIFEST..." +kubectl apply -f "$MANIFEST" + +echo "[3/4] Waiting for pod to start..." +for i in $(seq 1 60); do + POD=$(kubectl -n "$NS" get pod -l job-name="$JOB" -o name 2>/dev/null | head -1 || true) + if [[ -n "$POD" ]]; then + echo " pod: $POD" + break + fi + sleep 2 +done +if [[ -z "$POD" ]]; then + echo "ERROR: pod did not appear within 120s" + exit 1 +fi + +if [[ -n "$WATCH" ]]; then + echo "Tailing logs (Ctrl-C to detach; the Job continues)..." + kubectl -n "$NS" logs -f "$POD" || true +fi + +echo "[4/4] Waiting for Job to complete..." +kubectl -n "$NS" wait --for=condition=complete --timeout=30m "job/$JOB" || { + echo "Job didn't complete in 30m. Final state:" + kubectl -n "$NS" describe job "$JOB" | tail -30 + echo + echo "Last log lines:" + kubectl -n "$NS" logs "$POD" --tail=80 || true + exit 1 +} + +TS=$(date +%Y%m%d-%H%M%S) +OUT="$HERE/results-$TS" +mkdir -p "$OUT" +echo "Collecting results into $OUT/..." +kubectl -n "$NS" cp "${POD#pod/}:/results" "$OUT/" || { + echo "WARN: kubectl cp failed; pulling files individually..." + for scen in elastic_scale compile_target tree_fanout; do + kubectl -n "$NS" exec "$POD" -- cat "/results/$scen.json" > "$OUT/$scen.json" || true + done +} +echo +echo "Done. 
Files:" +ls -la "$OUT" +echo +echo "Summary:" +for scen in elastic_scale compile_target tree_fanout; do + if [[ -f "$OUT/$scen.json" ]]; then + echo " $scen:" + python3 -c " +import json +d = json.load(open('$OUT/$scen.json')) +print(' wall_seconds:', round(d['wall_seconds'], 2)) +print(' derived:', json.dumps(d['derived']['scenario_specific'], indent=6).replace('\\n', '\\n ')) +" + fi +done +echo +echo "To populate pensieve/RL/PrimeRL/11_benchmark_results.md, paste the" +echo "per-receiver tables from these JSON files into the matching sections." diff --git a/modelexpress_client/python/modelexpress/__init__.py b/modelexpress_client/python/modelexpress/__init__.py index 52cd6d56..7533a294 100644 --- a/modelexpress_client/python/modelexpress/__init__.py +++ b/modelexpress_client/python/modelexpress/__init__.py @@ -73,12 +73,49 @@ def register_modelexpress_loaders(): from .gds_loader import MxGdsLoader # noqa: F401 from .gds_transfer import GdsTransferManager # noqa: F401 from .metadata.heartbeat import HeartbeatThread # noqa: F401 +from .nemo_rl_v2 import ( # noqa: F401 + MxV2RefitReceiver, + MxV2TrainingPublisher, + SliceCoveragePlan, + SliceSource, + TargetTPLayout, + TrainerWorldLayout, + V2SourceCandidate, +) +from .shape_descriptors import ( # noqa: F401 + COMPILE_TARGET_CUTLASS_FP8, + COMPILE_TARGET_DEEPGEMM_FP8, + COMPILE_TARGET_HF_RAW, + COMPILE_TARGET_TRTLLM, + COMPILE_TARGET_VLLM_FUSED, + TensorDescriptorV2, + compile_target_matches, +) +from .training_publisher import MxTrainingPublisher # noqa: F401 +from .refit_receiver import MxRefitReceiver, TransferStats # noqa: F401 __all__ = [ + "COMPILE_TARGET_CUTLASS_FP8", + "COMPILE_TARGET_DEEPGEMM_FP8", + "COMPILE_TARGET_HF_RAW", + "COMPILE_TARGET_TRTLLM", + "COMPILE_TARGET_VLLM_FUSED", "GdsTransferManager", "HeartbeatThread", "MxClient", "MxGdsLoader", + "MxRefitReceiver", + "MxTrainingPublisher", + "MxV2RefitReceiver", + "MxV2TrainingPublisher", + "SliceCoveragePlan", + "SliceSource", + "TargetTPLayout", + 
"TensorDescriptorV2", + "TrainerWorldLayout", + "TransferStats", + "V2SourceCandidate", + "compile_target_matches", "configure_vllm_logging", "register_modelexpress_loaders", ] diff --git a/modelexpress_client/python/modelexpress/nemo_rl_v2.py b/modelexpress_client/python/modelexpress/nemo_rl_v2.py new file mode 100644 index 00000000..20819e9a --- /dev/null +++ b/modelexpress_client/python/modelexpress/nemo_rl_v2.py @@ -0,0 +1,1296 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""V2 NemoRL helpers built on top of MxTrainingPublisher / MxRefitReceiver. + +This module implements the design from +``pensieve/RL/NemoRL/04_design_v2_moe_rank_to_rank.md`` as a Python-only +shim that doesn't require proto/Rust changes. The shim: + +1. Encodes per-tensor shape + placement + expert metadata into + ``SourceIdentity.extra_parameters`` (JSON document under key + ``shape_registry``). See :mod:`modelexpress.shape_descriptors`. + +2. Defaults to **same-rank-only transfers** (lesson from PrimeRL on + GB200; cross-subnet full-mesh fails on multi-NIC fabrics). Each + inference rank N pulls only from trainer rank N (or another + inference rank N that's already received via tree fan-out). + +3. Implements **tree fan-out / pipeline replication** by having + inference receivers republish themselves with NIXL after + receiving — subsequent receivers can pull from them. Source + selection prefers the trainer first, then any peer that's + ahead of us at the same ``worker_rank``. + +4. Encodes **owned / needed expert IDs** into ``extra_parameters`` + so a receiver in EP mode can skip non-owned experts entirely. + +5. Wraps :class:`HeartbeatThread` so v2 publishers / receivers come + with liveness signaling out of the box. The MX-side reaper can + correctly distinguish quiet-but-alive workers from dead ones. 
+ +This is a **prototype-grade** shim: the eventual production answer is +new RPCs (PickSource, GetShapeRegistry, SetDirtyExperts, ...) on the +MX server, with full TopologyScheduler logic in Rust. See +``pensieve/RL/NemoRL/05_mx_helpers_needed.md`` for the proto migration. +""" + +from __future__ import annotations + +import json +import logging +from dataclasses import dataclass +from typing import Iterator + +import torch + +from . import p2p_pb2 +from .metadata.heartbeat import HeartbeatThread +from .refit_receiver import MxRefitReceiver, SourceRef +from .shape_descriptors import ( + COMPILE_TARGET_HF_RAW, + PLACEMENT_SHARD, + TensorDescriptorV2, + compile_target_matches, + decode_expert_set, + decode_registry, + describe_tensor, + encode_expert_set, + encode_registry, +) # COMPILE_TARGET_HF_RAW is re-exported for callers passing it to add_tensor(). +from .training_publisher import MxTrainingPublisher + +logger = logging.getLogger("modelexpress.nemo_rl_v2") + + +# Role string written into ``extra_parameters["role"]``. Matches the +# convention adopted by PR #2389. Receivers filter on it to disambiguate. +ROLE_TRAINER = "trainer" +ROLE_INFERENCE = "inference" +ROLE_INFERENCE_REPLICA = "inference_replica" + + +def _slice_along_axis( + tensor: torch.Tensor, axis: int, rng: tuple[int, int] +) -> torch.Tensor: + """View ``tensor[..., rng[0]:rng[1], ...]`` along ``axis``. + + Phase-4 helper: lifts a publisher's local-shard bytes into the + receiver's destination slice. Returns a view (no copy) when ``tensor`` + is already contiguous along ``axis``; otherwise yields a contiguous + clone so subsequent ``torch.cat`` is well-defined. + """ + start, end = rng + if tensor.ndim == 0 or start == 0 and end == tensor.shape[axis]: + return tensor + idx: list[slice] = [slice(None)] * tensor.ndim + idx[axis] = slice(start, end) + out = tensor[tuple(idx)] + return out.contiguous() if not out.is_contiguous() else out + + +# Synthetic tensor descriptor used as a v2 metadata sidecar. 
The current +# Rust MX server drops most string fields (agent_name, extra_parameters, +# metadata_endpoint, etc.) when echoing a WorkerMetadata back via +# GetMetadata, but it preserves tensor descriptors. So we abuse a +# zero-size, magic-named TensorDescriptor as the transport: the JSON v2 +# payload goes in the ``dtype`` field, which is a freeform proto3 string +# the server stores verbatim. Receivers look for this marker and pull +# v2 fields from it. +_V2_SIDECAR_NAME = "__mx_v2_meta__" + + +# Trainer world layout descriptor. Receivers can sanity-check that the +# layout they expect matches what the trainer actually published. +@dataclass(frozen=True) +class TrainerWorldLayout: + """Compact descriptor for a trainer's parallelism layout.""" + + fsdp_world_size: int = 1 + tp_world_size: int = 1 + pp_world_size: int = 1 + ep_world_size: int = 1 + + def encode(self) -> str: + return ( + f"fsdp:{self.fsdp_world_size},tp:{self.tp_world_size}," + f"pp:{self.pp_world_size},ep:{self.ep_world_size}" + ) + + @classmethod + def decode(cls, s: str) -> "TrainerWorldLayout": + kv = {p.split(":")[0]: int(p.split(":")[1]) for p in s.split(",") if ":" in p} + return cls( + fsdp_world_size=kv.get("fsdp", 1), + tp_world_size=kv.get("tp", 1), + pp_world_size=kv.get("pp", 1), + ep_world_size=kv.get("ep", 1), + ) + + +class MxV2TrainingPublisher: + """v2 trainer-side publisher. + + Wraps :class:`MxTrainingPublisher` and adds: + + - **Shape registry**: per-tensor placement + expert info, JSON-encoded + and stashed in ``extra_parameters["shape_registry"]``. + - **Rank-to-rank semantics**: every rank publishes its OWN local shard; + no allgather, no bucket pack. + - **Heartbeat**: started automatically by :meth:`mark_ready`. + - **MoE expert metadata**: per-tensor ``owned_expert_ids`` propagated + to the receiver via the registry. + + Args: + agent_name: Unique NIXL agent name (e.g. ``"nemo-rl-trainer-r3"``). + device_id: CUDA device index. + mx_server_url: MX gRPC URL. 
+ worker_rank: Global rank within the trainer's parallelism group. + For FSDP-only this is the FSDP rank; for FSDP+TP+EP it should + map to the receiver's rank index in the same coord system. + world_layout: Total parallelism layout — receivers use it to + sanity-check expected shape. + listen_port: Optional NIXL listen port. + heartbeat: Whether to start a background heartbeat after + ``mark_ready``. Default True. + """ + + def __init__( + self, + *, + agent_name: str, + device_id: int, + mx_server_url: str, + worker_rank: int, + world_layout: TrainerWorldLayout, + listen_port: int | None = None, + heartbeat: bool = True, + ): + self._publisher = MxTrainingPublisher( + agent_name=agent_name, + device_id=device_id, + mx_server_url=mx_server_url, + listen_port=listen_port, + ) + self._worker_rank = worker_rank + self._world_layout = world_layout + self._heartbeat_enabled = heartbeat + self._heartbeat: HeartbeatThread | None = None + + self._registry: list[TensorDescriptorV2] = [] + self._registered_tensors: dict[str, torch.Tensor] = {} + self._initialized = False + + @property + def worker_rank(self) -> int: + return self._worker_rank + + @property + def mx_source_id(self) -> str | None: + return self._publisher.mx_source_id + + @property + def worker_id(self) -> str: + return self._publisher.worker_id + + def initialize(self, *, model_name: str, dtype: str = "bfloat16") -> None: + """Initialize the underlying NIXL agent + MX gRPC client.""" + self._publisher.initialize( + model_name=model_name, + tensor_parallel_size=self._world_layout.tp_world_size, + pipeline_parallel_size=self._world_layout.pp_world_size, + expert_parallel_size=self._world_layout.ep_world_size, + dtype=dtype, + training_framework="nemo_rl", + ) + self._initialized = True + logger.info( + "MxV2TrainingPublisher initialized: rank=%d layout=%s", + self._worker_rank, + self._world_layout.encode(), + ) + + def add_tensor( + self, + *, + name: str, + tensor: torch.Tensor, + is_expert: bool = False, + 
expert_axis: int = 0, + owned_expert_ids: tuple[int, ...] | set[int] | list[int] = (), + compile_target: str = COMPILE_TARGET_HF_RAW, + compile_metadata: dict[str, object] | None = None, + ) -> None: + """Register a tensor for publication. + + Each call appends the tensor and its descriptor to the in-flight + registry. Call :meth:`publish` once all tensors are added; that + single publish call registers everything with NIXL (once) and + emits one ``WorkerMetadata`` row. + + Args: + name: tensor's qualified state-dict name. + tensor: GPU tensor to publish. May be a DTensor or plain + tensor. **Must NOT be a materialized full tensor** — + pass ``tensor.to_local()`` for DTensors. The whole + point of v2 is to avoid the allgather. + is_expert: whether the tensor's leading axis is the MoE + expert axis (used for expert filtering). + expert_axis: axis index for the expert dimension. + owned_expert_ids: which expert IDs this rank holds. Pass + only when ``is_expert == True``. + compile_target: Phase-3a tag identifying the kernel layout + the bytes are encoded for. Defaults to ``"hf_raw"`` — + plain HF state-dict bytes, no kernel-specific layout. + Callers should pass the resolved ``ConversionEntry.compile_target`` + from their conversion registry (e.g. ``"cutlass_fp8"`` + for cutlass per-channel FP8, ``"deep_gemm_fp8"`` for + DeepGemm 128x128 blockwise). + compile_metadata: free-form key/value blob describing the + byte-affecting compile choices (block size, scale + layout, kernel version, etc.). Receivers filter on this + via :meth:`MxV2RefitReceiver.discover_v2_sources` + ``required_compile_metadata=`` so a Cutlass receiver + won't accidentally consume DeepGemm-block-256 bytes. 
+ """ + if not self._initialized: + raise RuntimeError("call initialize() before add_tensor()") + if not tensor.is_cuda: + raise RuntimeError( + f"tensor {name!r} is not on CUDA; v2 publish requires GPU residency" + ) + + descriptor = describe_tensor( + name=name, + tensor=tensor, + rank=self._worker_rank, + fsdp_world_size=self._world_layout.fsdp_world_size, + is_expert=is_expert, + expert_axis=expert_axis, + owned_expert_ids=tuple(sorted(owned_expert_ids)), + compile_target=compile_target, + compile_metadata=compile_metadata, + ) + self._registry.append(descriptor) + # Use a key that's unique per descriptor (including any potential + # name collisions from layer publishing). For v2 we publish all + # tensors at once, so the name is sufficient. + self._registered_tensors[name] = tensor + + def publish(self, *, version: int) -> str: + """Publish all added tensors as one ``WorkerMetadata`` row. + + Returns the ``mx_source_id`` (16-hex hash) assigned by the server. + """ + if not self._initialized: + raise RuntimeError("call initialize() before publish()") + if not self._registered_tensors: + raise RuntimeError( + "no tensors added; call add_tensor() before publish()" + ) + + registry_blob = encode_registry( + self._registry, + version=version, + trainer_world_layout=self._world_layout.encode(), + ) + + # Fold the v2 metadata into the underlying publisher's + # extra_parameters via a monkey-patched _build_identity (the + # forward-compatible path) AND attach a synthetic + # ``TensorDescriptor(name=_V2_SIDECAR_NAME, dtype=)`` to the + # outgoing WorkerMetadata (the path that survives the current + # Rust server's GetMetadata field-dropping). Receivers look at + # both: identity.extra_parameters first, then the sidecar + # descriptor. 
+ original_build_identity = self._publisher._build_identity + + def _build_identity_with_v2(step: int) -> p2p_pb2.SourceIdentity: + ident = original_build_identity(step) + ident.extra_parameters["role"] = ROLE_TRAINER + ident.extra_parameters["mx_v2"] = "1" + ident.extra_parameters["worker_rank"] = str(self._worker_rank) + ident.extra_parameters["shape_registry"] = registry_blob + ident.extra_parameters["world_layout"] = self._world_layout.encode() + return ident + + # Build the v2 sidecar payload (preserves all the same data as + # extra_parameters but in a transport the server actually echoes). + # ``shape_registry`` is intentionally embedded as a nested JSON string + # inside this JSON document — receivers parse the outer JSON with + # decode_registry's matching call to handle the inner blob. + sidecar_payload = json.dumps( + { + "mx_v2": "1", + "role": ROLE_TRAINER, + "worker_rank": int(self._worker_rank), + "training_step": int(version), + "world_layout": self._world_layout.encode(), + "framework": "nemo_rl", + "shape_registry": registry_blob, + }, + separators=(",", ":"), + ) + + # Wrap the agent_name with v2 markers (legacy-server fallback path 2). + original_agent_name = self._publisher._agent_name + self._publisher._agent_name = ( + f"mx_v2|{ROLE_TRAINER}|rank={self._worker_rank}|" + f"version={int(version)}|orig={original_agent_name}" + ) + self._publisher._build_identity = _build_identity_with_v2 # type: ignore[method-assign] + + # Wrap _build_tensor_protos to append the sidecar descriptor. 
+ original_build_tensor_protos = self._publisher._build_tensor_protos + + def _build_tensor_protos_with_sidecar(descriptors): + protos = original_build_tensor_protos(descriptors) + sidecar = p2p_pb2.TensorDescriptor( + name=_V2_SIDECAR_NAME, + addr=0, + size=0, + device_id=0, + dtype=sidecar_payload, + ) + protos.append(sidecar) + return protos + + self._publisher._build_tensor_protos = _build_tensor_protos_with_sidecar # type: ignore[method-assign] + + try: + mx_source_id = self._publisher.publish_weights( + named_tensors=self._registered_tensors, + step=int(version), + worker_rank=self._worker_rank, + ) + finally: + self._publisher._build_identity = original_build_identity # type: ignore[method-assign] + self._publisher._agent_name = original_agent_name + self._publisher._build_tensor_protos = original_build_tensor_protos # type: ignore[method-assign] + + logger.info( + "MxV2 publish: rank=%d version=%d tensors=%d mx_source_id=%s", + self._worker_rank, + version, + len(self._registered_tensors), + mx_source_id, + ) + return mx_source_id + + def mark_ready(self) -> bool: + """Mark this source as READY. 
Starts heartbeat if enabled.""" + ok = self._publisher.mark_ready(worker_rank=self._worker_rank) + if ok and self._heartbeat_enabled and self._heartbeat is None: + self._start_heartbeat() + return ok + + def _start_heartbeat(self) -> None: + if self._publisher._client is None or self._publisher._nixl is None: + logger.warning("cannot start heartbeat: publisher not initialized") + return + self._heartbeat = HeartbeatThread( + mx_client=self._publisher._client, + mx_source_id=self._publisher.mx_source_id or "", + worker_id=self._publisher.worker_id, + worker_rank=self._worker_rank, + nixl_manager=self._publisher._nixl, + ) + self._heartbeat.start() + + def shutdown(self) -> None: + """Stop heartbeat (marks STALE) and tear down the publisher.""" + if self._heartbeat is not None: + self._heartbeat.stop() + self._heartbeat = None + self._publisher.shutdown() + self._initialized = False + + +@dataclass +class V2SourceCandidate: + """A discovered source with v2 metadata parsed. + + ``compile_targets`` is the set of distinct ``TensorDescriptorV2.compile_target`` + values present across the candidate's registry. A receiver filters on this + via :meth:`MxV2RefitReceiver.discover_v2_sources` (``compile_target_filter=``). + The most common shapes: + + - ``{"hf_raw"}`` — clean HF state-dict bytes; any kernel-aware receiver + can compile from it. + - ``{"deep_gemm_fp8"}`` — already quantised + reordered for DeepGemm. + - ``{"hf_raw", "deep_gemm_fp8"}`` — mixed: some tensors raw, some compiled + (rare but legal; receivers must check per-tensor). + """ + + ref: SourceRef + role: str # "trainer" | "inference_replica" + worker_rank: int + registry: dict | None # decoded registry; None for inference_replica + owned_experts_per_layer: dict[int, set[int]] # layer_idx → expert IDs + updated_at: int # ms epoch + compile_targets: frozenset[str] = frozenset({COMPILE_TARGET_HF_RAW}) + + +@dataclass +class TargetTPLayout: + """What slice of the global tensor a Phase-4 receiver wants. 
+ + Phase 4 (mixed-TP / multi-source slice discovery, see post-#2389 RFC §5). + A receiver running at inference-time describes its local view by: + + - ``world_size``: the inference-side world that's splitting the tensor + (e.g. inference TP=8 even if trainer was TP=4). + - ``rank``: this receiver's rank within ``world_size``. Used to compute + an even slice by default. + - ``shard_axis``: which tensor axis is sharded across that world. + - ``target_range``: optional explicit ``(start, end)`` along + ``shard_axis`` to override the default even-split math. Necessary + for uneven layouts (e.g. expert sharding with custom owner maps). + + The publisher's ``placement_kind`` + ``local_shard_range`` are looked up + per tensor; the planner intersects ``target_range`` against every + publisher slice and emits the minimal candidate set. + """ + + world_size: int + rank: int + shard_axis: int = 0 + target_range: tuple[int, int] | None = None + + +@dataclass +class SliceSource: + """One source contribution toward filling a receiver's target slice. + + Emitted by :class:`SliceCoveragePlan`. The receiver issues one NIXL + RDMA read per ``SliceSource``, copying ``src_range`` bytes from the + candidate's buffer into ``dst_range`` of the local destination. + """ + + candidate: V2SourceCandidate + tensor_name: str + src_range: tuple[int, int] + dst_range: tuple[int, int] + shard_axis: int + + +@dataclass +class _TensorPlan: + """Internal: result of planning one tensor's coverage.""" + + contributions: list[SliceSource] + reason: str + + +@dataclass +class SliceCoveragePlan: + """Result of :meth:`MxV2RefitReceiver.discover_v2_sources_for_slice`. + + Fields: + candidates: every v2 candidate that passed the filter. + per_tensor_sources: per-tensor list of :class:`SliceSource` + describing how to fill that tensor's target slice. An empty + entry means no plan was found (see ``missing``). + missing: list of ``"name: reason"`` for tensors that couldn't be + fully covered. 
If non-empty, the receiver should treat the + plan as failed. + target_layout: echoed back for convenience. + legacy_single_source: True iff the picker found candidates but + none carried a v2 registry — in that case ``per_tensor_sources`` + is empty and the caller should fall back to + :meth:`MxV2RefitReceiver.receive_from` with ``candidates[0]``. + """ + + candidates: list[V2SourceCandidate] + per_tensor_sources: dict[str, list[SliceSource]] + missing: list[str] + target_layout: TargetTPLayout + legacy_single_source: bool = False + + @property + def fully_covered(self) -> bool: + return not self.missing and ( + self.legacy_single_source or bool(self.per_tensor_sources) + ) + + +class MxV2RefitReceiver: + """v2 inference-side receiver. + + Wraps :class:`MxRefitReceiver` and adds: + + - **Same-rank source selection**: by default, picks a candidate with + ``worker_rank == self.worker_rank``. Falls back to other ranks only + if explicitly requested. + + - **Freshest-first dedup**: when multiple candidates match the rank + filter, picks the one with the latest ``updated_at``. (Same fix + as PrimeRL's runtime patch — applied as the default here.) + + - **Tree fan-out**: after a successful receive, optionally calls + :meth:`publish_self_as_source` to make this rank's buffers + available to subsequent receivers. + + - **Expert filtering**: when ``my_owned_experts_per_layer`` is set, + receives only the slices of expert tensors that this rank actually + uses. + """ + + def __init__( + self, + *, + agent_name: str, + device_id: int, + mx_server_url: str, + worker_rank: int, + listen_port: int | None = None, + ): + self._receiver = MxRefitReceiver( + agent_name=agent_name, + device_id=device_id, + mx_server_url=mx_server_url, + listen_port=listen_port, + ) + self._worker_rank = worker_rank + self._initialized = False + self._registered_buffers: dict[str, torch.Tensor] = {} + + # Metrics surface for benchmarks / dashboards. 
Discovery numbers + # are at the v2 layer (catalog walk + per-instance get_metadata); + # the per-transfer RDMA numbers live on the wrapped MxRefitReceiver + # in `self._receiver.last_stats` and `self._receiver.history`. + self._last_discovery_seconds: float = 0.0 + self._last_discovery_candidates: int = 0 + + @property + def worker_rank(self) -> int: + return self._worker_rank + + def initialize( + self, + *, + model_tensors: dict[str, torch.Tensor] | None = None, + ) -> None: + """Initialize NIXL agent + MX client. Optionally register receive buffers.""" + self._receiver.initialize(model_tensors=model_tensors) + if model_tensors: + self._registered_buffers = dict(model_tensors) + self._initialized = True + logger.info( + "MxV2RefitReceiver initialized: rank=%d buffers=%d", + self._worker_rank, + len(self._registered_buffers), + ) + + def discover_v2_sources( + self, + *, + model_name: str, + min_version: int = 0, + same_rank_only: bool = True, + include_replicas: bool = True, + compile_target_filter: set[str] | frozenset[str] | None = None, + required_compile_metadata: dict[str, object] | None = None, + ) -> list[V2SourceCandidate]: + """List candidate v2 sources, filtering and sorting per the v2 rules. + + Args: + model_name: model name to filter on. + min_version: only return sources whose ``version`` (== training + step) is at least this. + same_rank_only: if True (default), only return candidates whose + ``worker_rank`` equals this receiver's rank. + include_replicas: whether to include other inference ranks that + have already received and republished. Combined with + ``same_rank_only``, this means "same-rank trainer + any + same-rank inference replica". + compile_target_filter: receiver-side whitelist of acceptable + ``compile_target`` strings. A candidate is admitted only if + *every* tensor in its registry has a compile_target in this + set (mixed-layout candidates are rejected — see RFC §5). 
+ ``None`` (default) accepts everything, matching pre-Phase-3 + behaviour. + required_compile_metadata: optional kv pairs that every tensor's + ``compile_metadata`` must match. Use for pinning block sizes, + scale layouts, kernel versions, etc. + + Returns: + Candidates sorted by freshness (largest ``updated_at`` first). + Empty list if none matched. + """ + if not self._initialized: + raise RuntimeError("call initialize() before discover_v2_sources()") + + # Track catalog discovery time on the underlying receiver's metrics + # so benchmarks can see control-plane latency vs RDMA latency. + import time as _time + discovery_start = _time.monotonic() + + client = self._receiver._client + assert client is not None, "_receiver._client must be set after initialize()" + try: + response = client.list_sources( + status_filter=p2p_pb2.SOURCE_STATUS_READY, + ) + except Exception as e: # noqa: BLE001 + logger.warning("list_sources failed: %s", e) + self._last_discovery_seconds = _time.monotonic() - discovery_start + return [] + + candidates: list[V2SourceCandidate] = [] + for instance in response.instances: + if instance.model_name != model_name: + continue + + # Resolve the full identity to read v2 metadata. + try: + meta = client.get_metadata( + instance.mx_source_id, instance.worker_id + ) + except Exception as e: # noqa: BLE001 + logger.debug( + "get_metadata failed for %s: %s", instance.worker_id, e + ) + continue + if not getattr(meta, "found", False): + continue + + # Read v2 metadata. We try three transports in order: + # (a) SourceIdentity.extra_parameters (the cleanest path; works + # once the Rust server populates GetMetadataResponse.identity). + # (b) Synthetic TensorDescriptor sidecar named ``__mx_v2_meta__`` + # (preserved by the current Rust server; the path the + # prototype actually uses today). + # (c) WorkerMetadata.agent_name string-encoded marker (legacy). 
+ identity = getattr(meta, "identity", None) + extra: dict[str, str] = ( + dict(identity.extra_parameters) + if identity is not None and identity.extra_parameters + else {} + ) + if not extra: + # Sidecar transport: scan tensors for the magic marker. + for td in meta.worker.tensors: + if td.name == _V2_SIDECAR_NAME and td.dtype: + try: + sidecar = json.loads(td.dtype) + if isinstance(sidecar, dict): + for k, v in sidecar.items(): + extra[k] = str(v) + except (json.JSONDecodeError, TypeError): + pass + break + if not extra: + # Agent-name transport: "mx_v2|{role}|rank=N|version=K|orig=...". + agent_name = getattr(meta.worker, "agent_name", "") or "" + if agent_name.startswith("mx_v2|"): + parts = agent_name.split("|") + if len(parts) >= 4: + extra["mx_v2"] = "1" + extra["role"] = parts[1] + for piece in parts[2:]: + if "=" in piece: + k, v = piece.split("=", 1) + if k == "rank": + extra["worker_rank"] = v + elif k == "version": + extra["training_step"] = v + if extra.get("mx_v2") != "1": + # Not a v2 source; ignore. + continue + role = extra.get("role", "") + # Trainers are always eligible; include_replicas only gates replicas. + if role not in (ROLE_TRAINER, ROLE_INFERENCE_REPLICA): + continue + if role == ROLE_INFERENCE_REPLICA and not include_replicas: + continue + + try: + src_rank = int(extra.get("worker_rank", "-1")) + except ValueError: + continue + if same_rank_only and src_rank != self._worker_rank: + continue + + try: + version = int(extra.get("training_step", "0")) + except ValueError: + continue + if version < min_version: + continue + + registry_blob = extra.get("shape_registry", "") + registry = decode_registry(registry_blob) if registry_blob else None + + # Phase 3b: enforce compile_target_filter / required_compile_metadata. + # We require ALL tensors in the registry to match; partial matches + # would mean the receiver consumes some bytes correctly and silently + # corrupts others. If the candidate has no registry (e.g.
v0 trainer + # or an inference replica that didn't republish a registry), we + # admit it only when no filter is set, so callers explicitly opt in + # to compile-aware behaviour. + descriptors = [ + t + for t in (registry["tensors"] if registry else []) + if isinstance(t, TensorDescriptorV2) + ] + if compile_target_filter is not None or required_compile_metadata: + if not descriptors: + logger.debug( + "skipping candidate worker_id=%s: compile filter set " + "but candidate has no v2 registry", + instance.worker_id, + ) + continue + allowed = ( + frozenset(compile_target_filter) + if compile_target_filter is not None + else None + ) + ok = all( + compile_target_matches( + d, + allowed_targets=allowed, + required_metadata=required_compile_metadata, + ) + for d in descriptors + ) + if not ok: + logger.debug( + "skipping candidate worker_id=%s: compile filter mismatch " + "(targets=%s, want=%s)", + instance.worker_id, + sorted({d.compile_target for d in descriptors}), + sorted(compile_target_filter) if compile_target_filter else "*", + ) + continue + + compile_targets = ( + frozenset(d.compile_target for d in descriptors) + if descriptors + else frozenset({COMPILE_TARGET_HF_RAW}) + ) + + owned_blob = extra.get("owned_experts_per_layer", "") + owned_experts_per_layer: dict[int, set[int]] = {} + if owned_blob: + # encoding: "L0:0,1,2|L1:3,4,5" + for chunk in owned_blob.split("|"): + if ":" not in chunk: + continue + lid, ids = chunk.split(":", 1) + owned_experts_per_layer[int(lid.lstrip("L"))] = decode_expert_set( + ids + ) + + updated_at = int(getattr(meta.worker, "updated_at", 0) or 0) + + candidates.append( + V2SourceCandidate( + ref=SourceRef( + mx_source_id=instance.mx_source_id, + worker_id=instance.worker_id, + model_name=instance.model_name, + worker_rank=src_rank, + training_step=version, + ), + role=role, + worker_rank=src_rank, + registry=registry, + owned_experts_per_layer=owned_experts_per_layer, + updated_at=updated_at, + compile_targets=compile_targets, + 
) + ) + + # Topology score: prefer trainer over inference_replica (trainer is + # always authoritative); within that, prefer freshest. + candidates.sort( + key=lambda c: ( + 0 if c.role == ROLE_TRAINER else 1, + -c.updated_at, + ) + ) + self._last_discovery_seconds = _time.monotonic() - discovery_start + self._last_discovery_candidates = len(candidates) + return candidates + + def pick_best_source( + self, + candidates: list[V2SourceCandidate], + *, + needed_experts_per_layer: dict[int, set[int]] | None = None, + ) -> V2SourceCandidate | None: + """Pick the best candidate. Optionally requires expert coverage. + + If ``needed_experts_per_layer`` is set, the candidate must own a + superset of the requested experts (or be a trainer with full info). + """ + if not candidates: + return None + if needed_experts_per_layer is None: + return candidates[0] + + for cand in candidates: + if cand.role == ROLE_TRAINER: + # Trainer publishes its rank's owned set in the registry; if + # we need experts the trainer doesn't own, no single source + # has them and the caller has to multi-source. v0 punts. + return cand + covers_all = all( + needed.issubset(cand.owned_experts_per_layer.get(layer, set())) + for layer, needed in needed_experts_per_layer.items() + ) + if covers_all: + return cand + return None + + def receive_from( + self, + candidate: V2SourceCandidate, + *, + timeout_seconds: float = 300.0, + ) -> Iterator[tuple[str, torch.Tensor]]: + """Pull the candidate's tensors via NIXL RDMA into our pre-registered buffers. + + Wraps :meth:`MxRefitReceiver.receive_weights`. Yielded tensors are + the same buffers that were registered at ``initialize`` time. 
+ """ + yield from self._receiver.receive_weights( + candidate.ref, timeout_seconds=timeout_seconds + ) + + def receive_from_scratch( + self, + candidate: V2SourceCandidate, + *, + timeout_seconds: float = 300.0, + tensor_shapes: dict[str, tuple[int, ...]] | None = None, + ) -> Iterator[tuple[str, torch.Tensor]]: + """Pull the candidate's tensors via NIXL into receiver-allocated buffers. + + Wraps :meth:`MxRefitReceiver.receive_weights_scratch`. Use this + when the caller has no pre-registered model parameters to + receive into — e.g. cold-start in a vLLM worker before + ``model.load_weights()``, or the benchmark harness. Yielded + tensors are short-lived scratch buffers; copy them out or feed + them through ``load_weights`` before the next call. + """ + yield from self._receiver.receive_weights_scratch( + candidate.ref, + timeout_seconds=timeout_seconds, + tensor_shapes=tensor_shapes, + ) + + def receive_via_plan( + self, + plan: "SliceCoveragePlan", + *, + timeout_seconds: float = 300.0, + tensor_shapes: dict[str, tuple[int, ...]] | None = None, + ) -> Iterator[tuple[str, torch.Tensor]]: + """Multi-source receive driven by a :class:`SliceCoveragePlan` (Phase 4). + + For each candidate in the plan, issue a scratch RDMA receive and + copy the publisher's bytes into the receiver-local slice given by + ``SliceSource.dst_range``. Yields one ``(name, tensor)`` per tensor + in the plan, where ``tensor`` is the stitched view of the + receiver's requested slice. + + **v0 caveats** (intentional, see post-#2389 RFC §5): + + 1. We issue one full ``receive_weights_scratch`` per contributing + candidate. If the publisher's local shard is larger than the + receiver's slice (the common case for trainer-TP=4 → inference-TP=8), + this transfers more bytes than strictly necessary. Phase 4.5 will + push a byte-offset/byte-length argument into the NIXL transfer + manager so we issue partial reads. + + 2. 
We do an in-process ``torch.cat`` along ``shard_axis`` to stitch + the contributions. For 2 contributions this is one extra D2D copy + per tensor; for N=4+ it would warrant a fused kernel. v0 doesn't + bother. + + 3. Falls back to single-source :meth:`receive_from` if + ``plan.legacy_single_source`` is True (no v2 registry on any + candidate, e.g. talking to a v1-only deployment). + """ + if not plan.fully_covered: + raise RuntimeError( + f"plan is not fully covered; missing={plan.missing}" + ) + if plan.legacy_single_source: + if not plan.candidates: + raise RuntimeError("legacy_single_source plan has no candidates") + yield from self.receive_from( + plan.candidates[0], timeout_seconds=timeout_seconds + ) + return + + # Walk contributing candidates; cache scratch results so we only + # issue one RDMA pull per candidate even if several tensors share it. + cand_to_scratch: dict[str, dict[str, torch.Tensor]] = {} + for contributions in plan.per_tensor_sources.values(): + for src in contributions: + key = src.candidate.ref.worker_id + if key in cand_to_scratch: + continue + # Clone each tensor as it is yielded: the scratch buffers are + # documented as short-lived (see receive_from_scratch), so they + # may be reused by the next yield or the next RDMA pull. + scratch = { + name: tensor.clone() + for name, tensor in self._receiver.receive_weights_scratch( + src.candidate.ref, + timeout_seconds=timeout_seconds, + tensor_shapes=tensor_shapes, + ) + } + cand_to_scratch[key] = scratch + + for tensor_name, contributions in plan.per_tensor_sources.items(): + # If a single contribution covers the slice, no stitching needed.
+ if len(contributions) == 1: + src = contributions[0] + buf = cand_to_scratch[src.candidate.ref.worker_id][tensor_name] + yield tensor_name, _slice_along_axis( + buf, src.shard_axis, src.src_range + ) + continue + + slices = [] + for src in contributions: + buf = cand_to_scratch[src.candidate.ref.worker_id][tensor_name] + slices.append(_slice_along_axis(buf, src.shard_axis, src.src_range)) + stitched = torch.cat(slices, dim=contributions[0].shard_axis) + yield tensor_name, stitched + + def discover_v2_sources_for_slice( + self, + *, + model_name: str, + target_layout: "TargetTPLayout", + min_version: int = 0, + same_rank_only: bool = False, + include_replicas: bool = True, + compile_target_filter: set[str] | frozenset[str] | None = None, + required_compile_metadata: dict[str, object] | None = None, + ) -> "SliceCoveragePlan": + """Phase-4 multi-source picker: returns the minimal candidate set covering ``target_layout``. + + This is the entry point for **mixed-TP** receivers. The receiver + states the slice it wants (``target_layout`` describes its own TP + world size, this rank's TP rank, and which axis is the shard axis), + and we walk all v2 candidates to find the smallest set whose union + of ``local_shard_range`` covers that slice — *per tensor*. + + Unlike :meth:`discover_v2_sources`, ``same_rank_only`` defaults to + ``False`` here: with mixed-TP the obvious case is "trainer TP=4, + inference TP=8, so each inference rank pulls from one or two trainer + ranks", which inherently requires reading across publisher ranks. + + Returns a :class:`SliceCoveragePlan` whose ``per_tensor_sources`` maps + tensor name → ordered list of ``(candidate, src_range, dst_range)`` + slice descriptors. If any tensor cannot be fully covered, the plan's + ``missing`` list is non-empty and the caller should error out. 
+ """ + if not self._initialized: + raise RuntimeError( + "call initialize() before discover_v2_sources_for_slice()" + ) + + candidates = self.discover_v2_sources( + model_name=model_name, + min_version=min_version, + same_rank_only=same_rank_only, + include_replicas=include_replicas, + compile_target_filter=compile_target_filter, + required_compile_metadata=required_compile_metadata, + ) + + per_tensor_sources: dict[str, list[SliceSource]] = {} + missing: list[str] = [] + covered_tensors: set[str] = set() + + # Aggregate candidates by tensor name. A tensor is "covered" when the + # union of admitted candidates' local_shard_ranges (on the requested + # shard_axis) contains the target slice. + all_tensor_names: set[str] = set() + for cand in candidates: + if not cand.registry: + continue + for td in cand.registry.get("tensors", []): + if isinstance(td, TensorDescriptorV2): + all_tensor_names.add(td.name) + + for name in sorted(all_tensor_names): + # Per-tensor slice planning. Use the candidate's per-tensor + # global_shape + placement_kind to decide what THIS receiver needs. + # + # Slice math is intentionally simple in v0: + # - REPLICATE: any candidate provides the full bytes. Pick freshest. + # - SHARD on axis A == target_layout.shard_axis: receiver's slice + # is [target_start, target_end) over the global axis-A extent; + # we accumulate candidates whose local_shard_range intersects + # this slice, clipping each contribution to the wanted range. + # - SHARD on a different axis: not handled in v0 — emit a missing + # entry with a precise reason so the caller can fall back. + # - PARTIAL: not handled in v0. + chosen = self._plan_tensor_slice( + name=name, + candidates=candidates, + target_layout=target_layout, + ) + if chosen.contributions: + per_tensor_sources[name] = chosen.contributions + covered_tensors.add(name) + else: + missing.append(f"{name}: {chosen.reason}") + + # If the registry of every candidate is empty (e.g. 
transport drop) we + # should still return a plan so single-source receivers behave the + # same way they used to: flag it as legacy_single_source and leave + # per_tensor_sources empty. + legacy_mode = not all_tensor_names and bool(candidates) + if legacy_mode: + return SliceCoveragePlan( + candidates=candidates, + per_tensor_sources={}, + missing=[], + target_layout=target_layout, + legacy_single_source=True, + ) + + return SliceCoveragePlan( + candidates=candidates, + per_tensor_sources=per_tensor_sources, + missing=missing, + target_layout=target_layout, + legacy_single_source=False, + ) + + @staticmethod + def _plan_tensor_slice( + *, + name: str, + candidates: list[V2SourceCandidate], + target_layout: "TargetTPLayout", + ) -> "_TensorPlan": + # Find every (candidate, td) pair publishing this tensor. + published: list[tuple[V2SourceCandidate, TensorDescriptorV2]] = [] + for cand in candidates: + if not cand.registry: + continue + for td in cand.registry.get("tensors", []): + if isinstance(td, TensorDescriptorV2) and td.name == name: + published.append((cand, td)) + break + if not published: + return _TensorPlan(contributions=[], reason="no publishers") + + # All publishers must agree on global shape. dtype and placement kind + # are not cross-checked in v0; the first publisher's placement_kind + # is used below. + first_td = published[0][1] + for _, td in published[1:]: + if td.global_shape != first_td.global_shape: + return _TensorPlan( + contributions=[], + reason=f"shape disagreement {first_td.global_shape} vs {td.global_shape}", + ) + + if first_td.placement_kind == "REPLICATE": + # Any candidate satisfies. discover_v2_sources already sorted the + # candidates trainer-first, then freshest-first.
+ cand, td = published[0] + return _TensorPlan( + contributions=[ + SliceSource( + candidate=cand, + tensor_name=name, + src_range=(0, first_td.global_shape[0]) + if first_td.global_shape + else (0, 0), + dst_range=(0, first_td.global_shape[0]) + if first_td.global_shape + else (0, 0), + shard_axis=0, + ) + ], + reason="ok-replicate", + ) + + if first_td.placement_kind != "SHARD": + return _TensorPlan( + contributions=[], + reason=f"placement_kind={first_td.placement_kind} not supported in v0", + ) + + if first_td.shard_axis != target_layout.shard_axis: + return _TensorPlan( + contributions=[], + reason=( + f"shard_axis mismatch: publisher={first_td.shard_axis} " + f"target={target_layout.shard_axis}" + ), + ) + + axis_total = first_td.global_shape[target_layout.shard_axis] + # Receiver's wanted slice. Even split for v0 (if axis_total is not + # divisible by world_size, the trailing remainder is assigned to no + # rank); uneven splits are handled by the caller passing a custom + # ``target_layout`` that already encodes the start/end pair. + if target_layout.target_range is not None: + t_start, t_end = target_layout.target_range + else: + chunk = axis_total // target_layout.world_size + t_start = target_layout.rank * chunk + t_end = t_start + chunk + + # Walk publishers left-to-right along the shard axis, accumulating + # their intersections with the wanted slice. + contributions: list[SliceSource] = [] + covered_until = t_start + # Sort by start of local_shard_range. The sort is stable, so publishers + # of the same shard keep their discover_v2_sources order. + published_sorted = sorted( + published, + key=lambda pair: ( + pair[1].local_shard_range[0] if pair[1].local_shard_range else 0 + ), + ) + for cand, td in published_sorted: + if td.local_shard_range is None: + continue + p_start, p_end = td.local_shard_range + # Clip to covered_until so overlapping publishers (e.g. a trainer + # and a replica holding the same shard) are not counted twice. + inter_start = max(p_start, t_start, covered_until) + inter_end = min(p_end, t_end) + if inter_start >= inter_end: + continue + if inter_start > covered_until: + # gap: coverage incomplete.
+ return _TensorPlan( + contributions=[], + reason=( + f"coverage gap at axis {target_layout.shard_axis} " + f"[{covered_until}, {inter_start})" + ), + ) + contributions.append( + SliceSource( + candidate=cand, + tensor_name=name, + src_range=(inter_start - p_start, inter_end - p_start), + dst_range=(inter_start - t_start, inter_end - t_start), + shard_axis=target_layout.shard_axis, + ) + ) + covered_until = inter_end + if covered_until >= t_end: + break + if covered_until < t_end: + return _TensorPlan( + contributions=[], + reason=( + f"coverage gap at axis {target_layout.shard_axis} " + f"[{covered_until}, {t_end})" + ), + ) + return _TensorPlan(contributions=contributions, reason="ok-shard") + + def publish_self_as_source( + self, + *, + version: int, + model_name: str, + ) -> str | None: + """Make this receiver's buffers available to other receivers. + + Implements the TensorHub pipeline-replication trick: after we've + successfully received a version, we publish ourselves as an + ``inference_replica`` source so that any rank N receiver who hasn't + yet pulled can pull from us instead of contending on the trainer. + """ + if not self._registered_buffers: + logger.warning( + "publish_self_as_source: no registered buffers; skipping" + ) + return None + client = self._receiver._client + nixl = self._receiver._nixl + if client is None or nixl is None: + logger.warning( + "publish_self_as_source: receiver not initialized; skipping" + ) + return None + + # Build a lightweight identity declaring ourselves as a replica. 
+ identity = p2p_pb2.SourceIdentity( + model_name=model_name, + mx_source_type=p2p_pb2.MX_SOURCE_TYPE_WEIGHTS, + backend_framework=p2p_pb2.BACKEND_FRAMEWORK_UNKNOWN, + tensor_parallel_size=0, + pipeline_parallel_size=0, + expert_parallel_size=0, + dtype="bfloat16", # not load-bearing for replica; receivers ignore + quantization="", + extra_parameters={ + "role": ROLE_INFERENCE_REPLICA, + "mx_v2": "1", + "worker_rank": str(self._worker_rank), + "training_step": str(int(version)), + "training_framework": "nemo_rl", + }, + ) + + # Build a tensor-descriptor list from our already-registered buffers. + from .types import TensorDescriptor + + descriptors = nixl.tensor_descriptors # already populated at register time + worker_meta = p2p_pb2.WorkerMetadata( + worker_rank=self._worker_rank, + nixl_metadata=nixl.nixl_metadata, + tensors=[ + p2p_pb2.TensorDescriptor( + name=d.name, + addr=d.addr, + size=d.size, + device_id=d.device_id, + dtype=d.dtype, + ) + for d in descriptors + ], + status=p2p_pb2.SOURCE_STATUS_READY, + agent_name=self._receiver._agent_name, + ) + + try: + mx_source_id = client.publish_metadata( + identity=identity, + worker=worker_meta, + worker_id=self._receiver._worker_id + if hasattr(self._receiver, "_worker_id") + else "", + ) + except Exception as e: # noqa: BLE001 + logger.warning( + "publish_self_as_source failed: %s", e, exc_info=True + ) + return None + logger.info( + "Published self as inference_replica: rank=%d version=%d mx_source_id=%s", + self._worker_rank, + version, + mx_source_id, + ) + return mx_source_id + + def shutdown(self) -> None: + # NOTE: MxRefitReceiver.shutdown() (refit_receiver.py) releases the NIXL + # agent and closes the gRPC channel, but it is not called here yet, so + # teardown is currently left to Python's gc. TODO: call it from here. 
+ self._initialized = False + + +__all__ = [ + "MxV2RefitReceiver", + "MxV2TrainingPublisher", + "ROLE_INFERENCE", + "ROLE_INFERENCE_REPLICA", + "ROLE_TRAINER", + "TrainerWorldLayout", + "V2SourceCandidate", +] diff --git a/modelexpress_client/python/modelexpress/p2p_pb2.py b/modelexpress_client/python/modelexpress/p2p_pb2.py index c1a89ffe..38f7a8d3 100644 --- a/modelexpress_client/python/modelexpress/p2p_pb2.py +++ b/modelexpress_client/python/modelexpress/p2p_pb2.py @@ -27,7 +27,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tp2p.proto\x12\x11model_express.p2p\"\xce\x03\n\x0eSourceIdentity\x12\x12\n\nmx_version\x18\x01 \x01(\t\x12\x37\n\x0emx_source_type\x18\x02 \x01(\x0e\x32\x1f.model_express.p2p.MxSourceType\x12\x12\n\nmodel_name\x18\x03 \x01(\t\x12>\n\x11\x62\x61\x63kend_framework\x18\x04 \x01(\x0e\x32#.model_express.p2p.BackendFramework\x12\x1c\n\x14tensor_parallel_size\x18\x05 \x01(\r\x12\x1e\n\x16pipeline_parallel_size\x18\x06 \x01(\r\x12\x1c\n\x14\x65xpert_parallel_size\x18\x07 \x01(\r\x12\r\n\x05\x64type\x18\x08 \x01(\t\x12\x14\n\x0cquantization\x18\t \x01(\t\x12P\n\x10\x65xtra_parameters\x18\n \x03(\x0b\x32\x36.model_express.p2p.SourceIdentity.ExtraParametersEntry\x12\x10\n\x08revision\x18\x0b \x01(\t\x1a\x36\n\x14\x45xtraParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"^\n\x10TensorDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x61\x64\x64r\x18\x02 \x01(\x04\x12\x0c\n\x04size\x18\x03 \x01(\x04\x12\x11\n\tdevice_id\x18\x04 \x01(\r\x12\r\n\x05\x64type\x18\x05 \x01(\t\"\xc0\x02\n\x0eWorkerMetadata\x12\x13\n\x0bworker_rank\x18\x01 \x01(\r\x12\x17\n\rnixl_metadata\x18\x02 \x01(\x0cH\x00\x12$\n\x1atransfer_engine_session_id\x18\n \x01(\tH\x00\x12\x34\n\x07tensors\x18\x03 \x03(\x0b\x32#.model_express.p2p.TensorDescriptor\x12/\n\x06status\x18\x04 \x01(\x0e\x32\x1f.model_express.p2p.SourceStatus\x12\x12\n\nupdated_at\x18\x05 \x01(\x03\x12\x19\n\x11metadata_endpoint\x18\x06 
\x01(\t\x12\x12\n\nagent_name\x18\x07 \x01(\t\x12\x1c\n\x14worker_grpc_endpoint\x18\x08 \x01(\tB\x12\n\x10\x62\x61\x63kend_metadata\"0\n\x18GetTensorManifestRequest\x12\x14\n\x0cmx_source_id\x18\x01 \x01(\t\"\xab\x01\n\x19GetTensorManifestResponse\x12\x34\n\x07tensors\x18\x01 \x03(\x0b\x32#.model_express.p2p.TensorDescriptor\x12\x14\n\x0cmx_source_id\x18\x02 \x01(\t\x12\x19\n\x11metadata_endpoint\x18\x03 \x01(\t\x12\x12\n\nagent_name\x18\x04 \x01(\t\x12\x13\n\x0bworker_rank\x18\x05 \x01(\r\"\x93\x01\n\x16PublishMetadataRequest\x12\x33\n\x08identity\x18\x01 \x01(\x0b\x32!.model_express.p2p.SourceIdentity\x12\x31\n\x06worker\x18\x02 \x01(\x0b\x32!.model_express.p2p.WorkerMetadata\x12\x11\n\tworker_id\x18\x03 \x01(\t\"d\n\x17PublishMetadataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x14\n\x0cmx_source_id\x18\x03 \x01(\t\x12\x11\n\tworker_id\x18\x04 \x01(\t\"e\n\x11SourceInstanceRef\x12\x14\n\x0cmx_source_id\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\t\x12\x12\n\nmodel_name\x18\x03 \x01(\t\x12\x13\n\x0bworker_rank\x18\x04 \x01(\r\"\x98\x01\n\x12ListSourcesRequest\x12\x33\n\x08identity\x18\x01 \x01(\x0b\x32!.model_express.p2p.SourceIdentity\x12;\n\rstatus_filter\x18\x02 \x01(\x0e\x32\x1f.model_express.p2p.SourceStatusH\x00\x88\x01\x01\x42\x10\n\x0e_status_filter\"N\n\x13ListSourcesResponse\x12\x37\n\tinstances\x18\x01 \x03(\x0b\x32$.model_express.p2p.SourceInstanceRef\"=\n\x12GetMetadataRequest\x12\x14\n\x0cmx_source_id\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\t\"\x80\x01\n\x13GetMetadataResponse\x12\r\n\x05\x66ound\x18\x01 \x01(\x08\x12\x31\n\x06worker\x18\x02 \x01(\x0b\x32!.model_express.p2p.WorkerMetadata\x12\x14\n\x0cmx_source_id\x18\x03 \x01(\t\x12\x11\n\tworker_id\x18\x04 \x01(\t\"\x84\x01\n\x13UpdateStatusRequest\x12\x14\n\x0cmx_source_id\x18\x01 \x01(\t\x12\x13\n\x0bworker_rank\x18\x02 \x01(\r\x12/\n\x06status\x18\x03 \x01(\x0e\x32\x1f.model_express.p2p.SourceStatus\x12\x11\n\tworker_id\x18\x04 
\x01(\t\"8\n\x14UpdateStatusResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t*\x8a\x01\n\x10\x42\x61\x63kendFramework\x12\x1d\n\x19\x42\x41\x43KEND_FRAMEWORK_UNKNOWN\x10\x00\x12\x1a\n\x16\x42\x41\x43KEND_FRAMEWORK_VLLM\x10\x01\x12\x1c\n\x18\x42\x41\x43KEND_FRAMEWORK_SGLANG\x10\x02\x12\x1d\n\x19\x42\x41\x43KEND_FRAMEWORK_TRT_LLM\x10\x03*b\n\x0cMxSourceType\x12\x1a\n\x16MX_SOURCE_TYPE_WEIGHTS\x10\x00\x12\x17\n\x13MX_SOURCE_TYPE_LORA\x10\x01\x12\x1d\n\x19MX_SOURCE_TYPE_CUDA_GRAPH\x10\x02*{\n\x0cSourceStatus\x12\x19\n\x15SOURCE_STATUS_UNKNOWN\x10\x00\x12\x1e\n\x1aSOURCE_STATUS_INITIALIZING\x10\x01\x12\x17\n\x13SOURCE_STATUS_READY\x10\x02\x12\x17\n\x13SOURCE_STATUS_STALE\x10\x03\x32\x93\x03\n\nP2pService\x12h\n\x0fPublishMetadata\x12).model_express.p2p.PublishMetadataRequest\x1a*.model_express.p2p.PublishMetadataResponse\x12\\\n\x0bListSources\x12%.model_express.p2p.ListSourcesRequest\x1a&.model_express.p2p.ListSourcesResponse\x12\\\n\x0bGetMetadata\x12%.model_express.p2p.GetMetadataRequest\x1a&.model_express.p2p.GetMetadataResponse\x12_\n\x0cUpdateStatus\x12&.model_express.p2p.UpdateStatusRequest\x1a\'.model_express.p2p.UpdateStatusResponse2\x7f\n\rWorkerService\x12n\n\x11GetTensorManifest\x12+.model_express.p2p.GetTensorManifestRequest\x1a,.model_express.p2p.GetTensorManifestResponseb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\tp2p.proto\x12\x11model_express.p2p\"\xce\x03\n\x0eSourceIdentity\x12\x12\n\nmx_version\x18\x01 \x01(\t\x12\x37\n\x0emx_source_type\x18\x02 \x01(\x0e\x32\x1f.model_express.p2p.MxSourceType\x12\x12\n\nmodel_name\x18\x03 \x01(\t\x12>\n\x11\x62\x61\x63kend_framework\x18\x04 \x01(\x0e\x32#.model_express.p2p.BackendFramework\x12\x1c\n\x14tensor_parallel_size\x18\x05 \x01(\r\x12\x1e\n\x16pipeline_parallel_size\x18\x06 \x01(\r\x12\x1c\n\x14\x65xpert_parallel_size\x18\x07 \x01(\r\x12\r\n\x05\x64type\x18\x08 \x01(\t\x12\x14\n\x0cquantization\x18\t 
\x01(\t\x12P\n\x10\x65xtra_parameters\x18\n \x03(\x0b\x32\x36.model_express.p2p.SourceIdentity.ExtraParametersEntry\x12\x10\n\x08revision\x18\x0b \x01(\t\x1a\x36\n\x14\x45xtraParametersEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"^\n\x10TensorDescriptor\x12\x0c\n\x04name\x18\x01 \x01(\t\x12\x0c\n\x04\x61\x64\x64r\x18\x02 \x01(\x04\x12\x0c\n\x04size\x18\x03 \x01(\x04\x12\x11\n\tdevice_id\x18\x04 \x01(\r\x12\r\n\x05\x64type\x18\x05 \x01(\t\"\xc0\x02\n\x0eWorkerMetadata\x12\x13\n\x0bworker_rank\x18\x01 \x01(\r\x12\x17\n\rnixl_metadata\x18\x02 \x01(\x0cH\x00\x12$\n\x1atransfer_engine_session_id\x18\n \x01(\tH\x00\x12\x34\n\x07tensors\x18\x03 \x03(\x0b\x32#.model_express.p2p.TensorDescriptor\x12/\n\x06status\x18\x04 \x01(\x0e\x32\x1f.model_express.p2p.SourceStatus\x12\x12\n\nupdated_at\x18\x05 \x01(\x03\x12\x19\n\x11metadata_endpoint\x18\x06 \x01(\t\x12\x12\n\nagent_name\x18\x07 \x01(\t\x12\x1c\n\x14worker_grpc_endpoint\x18\x08 \x01(\tB\x12\n\x10\x62\x61\x63kend_metadata\"0\n\x18GetTensorManifestRequest\x12\x14\n\x0cmx_source_id\x18\x01 \x01(\t\"\xab\x01\n\x19GetTensorManifestResponse\x12\x34\n\x07tensors\x18\x01 \x03(\x0b\x32#.model_express.p2p.TensorDescriptor\x12\x14\n\x0cmx_source_id\x18\x02 \x01(\t\x12\x19\n\x11metadata_endpoint\x18\x03 \x01(\t\x12\x12\n\nagent_name\x18\x04 \x01(\t\x12\x13\n\x0bworker_rank\x18\x05 \x01(\r\"\x93\x01\n\x16PublishMetadataRequest\x12\x33\n\x08identity\x18\x01 \x01(\x0b\x32!.model_express.p2p.SourceIdentity\x12\x31\n\x06worker\x18\x02 \x01(\x0b\x32!.model_express.p2p.WorkerMetadata\x12\x11\n\tworker_id\x18\x03 \x01(\t\"d\n\x17PublishMetadataResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 \x01(\t\x12\x14\n\x0cmx_source_id\x18\x03 \x01(\t\x12\x11\n\tworker_id\x18\x04 \x01(\t\"e\n\x11SourceInstanceRef\x12\x14\n\x0cmx_source_id\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\t\x12\x12\n\nmodel_name\x18\x03 \x01(\t\x12\x13\n\x0bworker_rank\x18\x04 
\x01(\r\"\x98\x01\n\x12ListSourcesRequest\x12\x33\n\x08identity\x18\x01 \x01(\x0b\x32!.model_express.p2p.SourceIdentity\x12;\n\rstatus_filter\x18\x02 \x01(\x0e\x32\x1f.model_express.p2p.SourceStatusH\x00\x88\x01\x01\x42\x10\n\x0e_status_filter\"N\n\x13ListSourcesResponse\x12\x37\n\tinstances\x18\x01 \x03(\x0b\x32$.model_express.p2p.SourceInstanceRef\"=\n\x12GetMetadataRequest\x12\x14\n\x0cmx_source_id\x18\x01 \x01(\t\x12\x11\n\tworker_id\x18\x02 \x01(\t\"\xb5\x01\n\x13GetMetadataResponse\x12\r\n\x05\x66ound\x18\x01 \x01(\x08\x12\x31\n\x06worker\x18\x02 \x01(\x0b\x32!.model_express.p2p.WorkerMetadata\x12\x14\n\x0cmx_source_id\x18\x03 \x01(\t\x12\x11\n\tworker_id\x18\x04 \x01(\t\x12\x33\n\x08identity\x18\x05 \x01(\x0b\x32!.model_express.p2p.SourceIdentity\"\x84\x01\n\x13UpdateStatusRequest\x12\x14\n\x0cmx_source_id\x18\x01 \x01(\t\x12\x13\n\x0bworker_rank\x18\x02 \x01(\r\x12/\n\x06status\x18\x03 \x01(\x0e\x32\x1f.model_express.p2p.SourceStatus\x12\x11\n\tworker_id\x18\x04 \x01(\t\"8\n\x14UpdateStatusResponse\x12\x0f\n\x07success\x18\x01 \x01(\x08\x12\x0f\n\x07message\x18\x02 
\x01(\t*\x8a\x01\n\x10\x42\x61\x63kendFramework\x12\x1d\n\x19\x42\x41\x43KEND_FRAMEWORK_UNKNOWN\x10\x00\x12\x1a\n\x16\x42\x41\x43KEND_FRAMEWORK_VLLM\x10\x01\x12\x1c\n\x18\x42\x41\x43KEND_FRAMEWORK_SGLANG\x10\x02\x12\x1d\n\x19\x42\x41\x43KEND_FRAMEWORK_TRT_LLM\x10\x03*b\n\x0cMxSourceType\x12\x1a\n\x16MX_SOURCE_TYPE_WEIGHTS\x10\x00\x12\x17\n\x13MX_SOURCE_TYPE_LORA\x10\x01\x12\x1d\n\x19MX_SOURCE_TYPE_CUDA_GRAPH\x10\x02*{\n\x0cSourceStatus\x12\x19\n\x15SOURCE_STATUS_UNKNOWN\x10\x00\x12\x1e\n\x1aSOURCE_STATUS_INITIALIZING\x10\x01\x12\x17\n\x13SOURCE_STATUS_READY\x10\x02\x12\x17\n\x13SOURCE_STATUS_STALE\x10\x03\x32\x93\x03\n\nP2pService\x12h\n\x0fPublishMetadata\x12).model_express.p2p.PublishMetadataRequest\x1a*.model_express.p2p.PublishMetadataResponse\x12\\\n\x0bListSources\x12%.model_express.p2p.ListSourcesRequest\x1a&.model_express.p2p.ListSourcesResponse\x12\\\n\x0bGetMetadata\x12%.model_express.p2p.GetMetadataRequest\x1a&.model_express.p2p.GetMetadataResponse\x12_\n\x0cUpdateStatus\x12&.model_express.p2p.UpdateStatusRequest\x1a\'.model_express.p2p.UpdateStatusResponse2\x7f\n\rWorkerService\x12n\n\x11GetTensorManifest\x12+.model_express.p2p.GetTensorManifestRequest\x1a,.model_express.p2p.GetTensorManifestResponseb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -36,12 +36,12 @@ DESCRIPTOR._loaded_options = None _globals['_SOURCEIDENTITY_EXTRAPARAMETERSENTRY']._loaded_options = None _globals['_SOURCEIDENTITY_EXTRAPARAMETERSENTRY']._serialized_options = b'8\001' - _globals['_BACKENDFRAMEWORK']._serialized_start=2118 - _globals['_BACKENDFRAMEWORK']._serialized_end=2256 - _globals['_MXSOURCETYPE']._serialized_start=2258 - _globals['_MXSOURCETYPE']._serialized_end=2356 - _globals['_SOURCESTATUS']._serialized_start=2358 - _globals['_SOURCESTATUS']._serialized_end=2481 + _globals['_BACKENDFRAMEWORK']._serialized_start=2171 + _globals['_BACKENDFRAMEWORK']._serialized_end=2309 + 
_globals['_MXSOURCETYPE']._serialized_start=2311 + _globals['_MXSOURCETYPE']._serialized_end=2409 + _globals['_SOURCESTATUS']._serialized_start=2411 + _globals['_SOURCESTATUS']._serialized_end=2534 _globals['_SOURCEIDENTITY']._serialized_start=33 _globals['_SOURCEIDENTITY']._serialized_end=495 _globals['_SOURCEIDENTITY_EXTRAPARAMETERSENTRY']._serialized_start=441 @@ -67,13 +67,13 @@ _globals['_GETMETADATAREQUEST']._serialized_start=1730 _globals['_GETMETADATAREQUEST']._serialized_end=1791 _globals['_GETMETADATARESPONSE']._serialized_start=1794 - _globals['_GETMETADATARESPONSE']._serialized_end=1922 - _globals['_UPDATESTATUSREQUEST']._serialized_start=1925 - _globals['_UPDATESTATUSREQUEST']._serialized_end=2057 - _globals['_UPDATESTATUSRESPONSE']._serialized_start=2059 - _globals['_UPDATESTATUSRESPONSE']._serialized_end=2115 - _globals['_P2PSERVICE']._serialized_start=2484 - _globals['_P2PSERVICE']._serialized_end=2887 - _globals['_WORKERSERVICE']._serialized_start=2889 - _globals['_WORKERSERVICE']._serialized_end=3016 + _globals['_GETMETADATARESPONSE']._serialized_end=1975 + _globals['_UPDATESTATUSREQUEST']._serialized_start=1978 + _globals['_UPDATESTATUSREQUEST']._serialized_end=2110 + _globals['_UPDATESTATUSRESPONSE']._serialized_start=2112 + _globals['_UPDATESTATUSRESPONSE']._serialized_end=2168 + _globals['_P2PSERVICE']._serialized_start=2537 + _globals['_P2PSERVICE']._serialized_end=2940 + _globals['_WORKERSERVICE']._serialized_start=2942 + _globals['_WORKERSERVICE']._serialized_end=3069 # @@protoc_insertion_point(module_scope) diff --git a/modelexpress_client/python/modelexpress/refit_receiver.py b/modelexpress_client/python/modelexpress/refit_receiver.py new file mode 100644 index 00000000..117d469e --- /dev/null +++ b/modelexpress_client/python/modelexpress/refit_receiver.py @@ -0,0 +1,537 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +""" +Inference-side weight receiver for RL refit via ModelExpress. + +Wraps NixlTransferManager + MxClient to discover updated weights +published by the training side, pull them via RDMA, and yield +``(name, tensor)`` pairs compatible with vLLM's ``model.load_weights()``. + +Typical usage in a vLLM worker extension:: + + receiver = MxRefitReceiver("inference-0", device_id=0, mx_server_url="mx-server:8001") + receiver.initialize(model_tensors=dict(model.named_parameters())) + + source = receiver.poll_for_source(model_name="Qwen/Qwen2.5-1.5B") + if source is not None: + for name, tensor in receiver.receive_weights(source): + ... # load into model +""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass +from typing import Any, Iterator + +import torch + +from .client import MxClient +from .nixl_transfer import NixlTransferManager, is_nixl_available +from .types import TensorDescriptor +from . import p2p_pb2 + +logger = logging.getLogger("modelexpress.refit_receiver") + + +# Maps the dtype string the publisher writes into TensorDescriptor.dtype to a +# torch.dtype. Module-scope so all receiver paths share one definition (and so +# we don't rebuild it on every receive_weights_scratch call). +_DTYPE_MAP: dict[str, torch.dtype] = { + "torch.bfloat16": torch.bfloat16, + "torch.float16": torch.float16, + "torch.float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, +} + + +@dataclass +class SourceRef: + """Lightweight handle to a discovered weight source on the MX Server.""" + mx_source_id: str + worker_id: str + model_name: str + worker_rank: int + training_step: int + + +@dataclass +class TransferStats: + """Structured per-receive metrics. Populated after each receive_weights* + call and exposed on ``MxRefitReceiver.last_stats`` so benchmarks and + operators can query timing/throughput without parsing log lines. 
+ + Fields: + bytes_received: total bytes pulled via NIXL RDMA in this call. + Excludes the v2 sidecar (`__mx_v2_meta__`) since that is + filtered out before RDMA register. + bytes_skipped: bytes deliberately skipped by NIXL (e.g. tensors + with mismatched destination buffers). + tensors_received: count of real tensors (sidecars excluded). + elapsed_seconds: wall time of the underlying NIXL transfer. + Does NOT include discovery latency (catalog GetMetadata). + bandwidth_gbps: derived (``bytes_received * 8 / elapsed / 1e9``). + ``0.0`` when elapsed is 0. + discovery_seconds: wall time of MX-server discovery + metadata + fetch (catalog round-trip). Tracked separately from the + RDMA transfer so callers can compare control-plane vs + data-plane latencies — important for elastic-scale-up. + path: which receive path was used. One of: + ``"pre_registered"`` (receive_weights), ``"scratch"`` + (receive_weights_scratch), ``"from_metadata"`` + (receive_weights_from_metadata). + training_step: the source's training step (== version), echoed + back for log/dashboard joins. + source_worker_rank: which publisher rank the bytes came from. + For multi-source `receive_via_plan` this is None on the + aggregate stats and populated on the per-contribution + inner stats. + """ + + bytes_received: int = 0 + bytes_skipped: int = 0 + tensors_received: int = 0 + elapsed_seconds: float = 0.0 + bandwidth_gbps: float = 0.0 + discovery_seconds: float = 0.0 + path: str = "" + training_step: int = 0 + source_worker_rank: int | None = None + + def update_bandwidth(self) -> None: + """Recompute bandwidth from bytes_received + elapsed_seconds.""" + self.bandwidth_gbps = ( + (self.bytes_received * 8) / (self.elapsed_seconds * 1e9) + if self.elapsed_seconds > 0 + else 0.0 + ) + + +class MxRefitReceiver: + """Receives updated weights from a training process via ModelExpress RDMA. + + One instance per GPU rank on the inference side. 
Discovers training + sources via the MX Server, pulls weight tensors over NIXL RDMA, + and yields them for ``model.load_weights()``. + + Args: + agent_name: Unique NIXL agent name (e.g. ``"inference-rank-0"``). + device_id: CUDA device index for this inference rank. + mx_server_url: gRPC address of the ModelExpress server. + listen_port: Optional NIXL listen port for P2P metadata exchange. + """ + + def __init__( + self, + agent_name: str, + device_id: int, + mx_server_url: str = "localhost:8001", + listen_port: int | None = None, + ): + self._agent_name = agent_name + self._device_id = device_id + self._mx_server_url = mx_server_url + self._listen_port = listen_port + + # Last-call metrics; populated after each receive_weights* call. + # Callers (benchmarks, dashboards, the v2 wrapper) can read this + # after each invocation without parsing logs. + self.last_stats: TransferStats = TransferStats() + self.history: list[TransferStats] = [] # full per-call history; appended after each call + + self._nixl: NixlTransferManager | None = None + self._client: MxClient | None = None + self._initialized = False + self._current_step = -1 + + @property + def current_step(self) -> int: + """The most recently received training step.""" + return self._current_step + + def initialize(self, model_tensors: dict[str, torch.Tensor] | None = None) -> None: + """Initialize NIXL agent, MX client, and optionally register receive buffers. + + Args: + model_tensors: If provided, registers these tensors with NIXL as + receive buffers. For tensor-name-matched transfers, the source's + tensors are written directly into these buffers. If *None*, + the caller must register tensors separately. + """ + if not is_nixl_available(): + raise RuntimeError( + "NIXL is not available. Install nixl or build from source." 
+ ) + + self._nixl = NixlTransferManager( + agent_name=self._agent_name, + device_id=self._device_id, + listen_port=self._listen_port, + ) + self._nixl.initialize() + + if model_tensors is not None: + self._nixl.register_tensors(model_tensors) + logger.info( + f"Registered {len(model_tensors)} receive buffers with NIXL" + ) + + self._client = MxClient(server_url=self._mx_server_url) + self._initialized = True + logger.info( + f"MxRefitReceiver initialized: agent={self._agent_name}, " + f"device={self._device_id}" + ) + + def poll_for_source( + self, + model_name: str, + min_step: int | None = None, + status_filter: int = p2p_pb2.SOURCE_STATUS_READY, + timeout_seconds: float = 0, + ) -> SourceRef | None: + """Check the MX Server for a training source with updated weights. + + Args: + model_name: Model name to filter on (must match publisher's identity). + min_step: If set, only return sources with ``training_step >= min_step``. + Defaults to ``current_step + 1`` to only find newer versions. + timeout_seconds: If > 0, poll repeatedly until a source is found + or timeout is reached. If 0, check once and return immediately. + + Returns: + A :class:`SourceRef` if a matching source was found, else *None*. + + Note: + ``training_step`` is published in ``SourceIdentity.extra_parameters`` + but ``ListSourcesResponse.instances`` only carries + ``SourceInstanceRef`` (no ``extra_parameters``). To honor the + ``min_step`` contract, this method does a per-candidate + ``get_metadata`` lookup so it can read ``training_step`` from the + publisher's full ``SourceIdentity``. A future server-side fix + (adding ``training_step`` to ``SourceInstanceRef``) will let us + drop the extra round-trip. 
+ """ + if not self._initialized: + raise RuntimeError("Call initialize() before poll_for_source()") + + if min_step is None: + min_step = self._current_step + 1 + + deadline = time.perf_counter() + timeout_seconds + + while True: + try: + response = self._client.list_sources( + status_filter=status_filter, + ) + except Exception as e: # noqa: BLE001 — log + retry on transient gRPC error + logger.warning(f"list_sources failed: {e}") + if time.perf_counter() >= deadline: + return None + time.sleep(0.5) + continue + + for instance in response.instances: + if instance.model_name != model_name: + continue + + # Resolve training_step from the publisher's SourceIdentity so + # min_step can be enforced. Skip candidates whose metadata is + # unreachable or whose step is below the threshold. + step = self._resolve_training_step(instance) + if step is None or step < min_step: + continue + + return SourceRef( + mx_source_id=instance.mx_source_id, + worker_id=instance.worker_id, + model_name=instance.model_name, + worker_rank=instance.worker_rank, + training_step=step, + ) + + if time.perf_counter() >= deadline: + return None + time.sleep(0.5) + + def _resolve_training_step(self, instance: Any) -> int | None: + """Fetch the publisher's ``training_step`` from MX Server metadata. + + ``SourceInstanceRef`` (returned by ``list_sources``) doesn't expose + ``extra_parameters``, so we do a follow-up ``get_metadata`` to read + ``training_step`` from ``SourceIdentity.extra_parameters``. Returns + ``None`` if the metadata isn't available or the step can't be + parsed — caller should treat this as "skip candidate". 
+ """ + try: + meta = self._client.get_metadata(instance.mx_source_id, instance.worker_id) + except Exception as e: # noqa: BLE001 — gRPC failures are per-candidate, not fatal + logger.debug(f"get_metadata failed for {instance.worker_id}: {e}") + return None + if not getattr(meta, "found", False): + return None + identity = getattr(meta, "identity", None) + if identity is None: + return None + extra = getattr(identity, "extra_parameters", None) or {} + raw = extra.get("training_step") if hasattr(extra, "get") else None + if raw is None: + return None + try: + return int(raw) + except (TypeError, ValueError): + logger.debug(f"training_step={raw!r} not parseable as int; skipping") + return None + + def receive_weights( + self, + source: SourceRef, + timeout_seconds: float = 300.0, + ) -> Iterator[tuple[str, torch.Tensor]]: + """Receive weights from a discovered source via NIXL RDMA. + + Fetches the source's NIXL metadata and tensor descriptors from the + MX Server, establishes an RDMA connection, and transfers weight + tensors into locally registered buffers. + + Args: + source: A :class:`SourceRef` obtained from :meth:`poll_for_source`. + timeout_seconds: Maximum time to wait for the RDMA transfer. + + Yields: + ``(name, tensor)`` pairs suitable for ``model.load_weights()``. + """ + if not self._initialized: + raise RuntimeError("Call initialize() before receive_weights()") + + discovery_start = time.monotonic() + meta_resp = self._client.get_metadata( + mx_source_id=source.mx_source_id, + worker_id=source.worker_id, + ) + discovery_seconds = time.monotonic() - discovery_start + if not meta_resp.found: + raise RuntimeError( + f"Source {source.mx_source_id}/{source.worker_id} not found on MX Server" + ) + + worker = meta_resp.worker + # Filter out V2 sidecar TensorDescriptors (name="__mx_v2_meta__", + # addr=0, size=0). The V2 publisher uses them to smuggle metadata + # past the MX server's field-dropping; they aren't real RDMA + # targets. 
Leaving them in the source_tensors list propagates a + # (0,0,0) descriptor into prep_xfer_dlist which UCX rejects. + source_tensors = [ + TensorDescriptor( + name=t.name, + addr=t.addr, + size=t.size, + device_id=t.device_id, + dtype=t.dtype, + ) + for t in worker.tensors + if not t.name.startswith("__mx_") and t.size > 0 + ] + + transferred, skipped, elapsed = self._nixl.receive_from_source( + source_metadata=worker.nixl_metadata, + source_tensors=source_tensors, + timeout_seconds=timeout_seconds, + ) + + stats = TransferStats( + bytes_received=transferred, + bytes_skipped=skipped, + tensors_received=len(source_tensors), + elapsed_seconds=elapsed, + discovery_seconds=discovery_seconds, + path="pre_registered", + training_step=source.training_step, + source_worker_rank=source.worker_rank, + ) + stats.update_bandwidth() + self.last_stats = stats + self.history.append(stats) + + logger.info( + f"RDMA transfer complete: {transferred / 1e9:.2f} GB, " + f"{len(source_tensors)} tensors, {elapsed:.2f}s, " + f"{stats.bandwidth_gbps:.1f} Gbps (step={source.training_step})" + ) + + self._current_step = source.training_step + + for td in source_tensors: + if td.name in self._nixl._tensors: + yield td.name, self._nixl._tensors[td.name] + + def receive_weights_scratch( + self, + source: SourceRef, + timeout_seconds: float = 300.0, + tensor_shapes: dict[str, tuple[int, ...]] | None = None, + ) -> Iterator[tuple[str, torch.Tensor]]: + """Receive weights into scratch GPU buffers via NIXL RDMA. + + Unlike :meth:`receive_weights` which requires pre-registered model + buffers with matching tensor names, this method allocates temporary + GPU tensors that match the source's layout, transfers via RDMA, and + yields the results. The caller feeds these through + ``model.load_weights()`` which handles name mapping and tensor fusion. + + This is the correct approach when the source (trainer) publishes + HuggingFace-format weights but the target (vLLM) uses fused internal + parameter names. 
+ + Args: + source: A :class:`SourceRef` obtained from :meth:`poll_for_source`. + timeout_seconds: Maximum time to wait for the RDMA transfer. + + Yields: + ``(name, tensor)`` pairs in HF checkpoint format. + """ + if not self._initialized: + raise RuntimeError("Call initialize() before receive_weights_scratch()") + + discovery_start = time.monotonic() + meta_resp = self._client.get_metadata( + mx_source_id=source.mx_source_id, + worker_id=source.worker_id, + ) + discovery_seconds = time.monotonic() - discovery_start + if not meta_resp.found: + raise RuntimeError( + f"Source {source.mx_source_id}/{source.worker_id} not found on MX Server" + ) + + worker = meta_resp.worker + # Filter out V2 sidecar TensorDescriptors (name="__mx_v2_meta__", + # addr=0, size=0). The V2 publisher uses them to smuggle metadata + # past the MX server's field-dropping; they aren't real RDMA + # targets. Leaving them in the source_tensors list propagates a + # (0,0,0) descriptor into prep_xfer_dlist which UCX rejects. 
+ source_tensors = [ + TensorDescriptor( + name=t.name, + addr=t.addr, + size=t.size, + device_id=t.device_id, + dtype=t.dtype, + ) + for t in worker.tensors + if not t.name.startswith("__mx_") and t.size > 0 + ] + + scratch_tensors: dict[str, torch.Tensor] = {} + scratch_shapes: dict[str, tuple[int, ...]] = {} + for td in source_tensors: + dt = _DTYPE_MAP.get(td.dtype, torch.bfloat16) + elem_size = torch.tensor([], dtype=dt).element_size() + numel = td.size // elem_size + scratch_tensors[td.name] = torch.empty( + numel, dtype=dt, device=f"cuda:{self._device_id}" + ) + scratch_shapes[td.name] = (numel,) + + logger.info( + f"Allocated {len(scratch_tensors)} scratch buffers " + f"({sum(t.numel() * t.element_size() for t in scratch_tensors.values()) / 1e9:.2f} GB)" + ) + + self._nixl.register_tensors(scratch_tensors) + + transferred, skipped, elapsed = self._nixl.receive_from_source( + source_metadata=worker.nixl_metadata, + source_tensors=source_tensors, + timeout_seconds=timeout_seconds, + ) + + stats = TransferStats( + bytes_received=transferred, + bytes_skipped=skipped, + tensors_received=len(source_tensors), + elapsed_seconds=elapsed, + discovery_seconds=discovery_seconds, + path="scratch", + training_step=source.training_step, + source_worker_rank=source.worker_rank, + ) + stats.update_bandwidth() + self.last_stats = stats + self.history.append(stats) + + logger.info( + f"RDMA transfer complete: {transferred / 1e9:.2f} GB, " + f"{len(source_tensors)} tensors, {elapsed:.2f}s, " + f"{stats.bandwidth_gbps:.1f} Gbps (step={source.training_step})" + ) + + self._current_step = source.training_step + + for name, tensor in scratch_tensors.items(): + if tensor_shapes and name in tensor_shapes: + tensor = tensor.view(tensor_shapes[name]) + yield name, tensor + + def receive_weights_from_metadata( + self, + nixl_metadata: bytes, + source_tensors: list[TensorDescriptor], + training_step: int, + timeout_seconds: float = 300.0, + ) -> Iterator[tuple[str, torch.Tensor]]: + 
"""Receive weights when metadata is already available (bypasses MX Server query). + + Useful when the orchestrator passes metadata directly instead of + having the worker poll the MX Server. + """ + if not self._initialized: + raise RuntimeError("Call initialize() first") + + transferred, skipped, elapsed = self._nixl.receive_from_source( + source_metadata=nixl_metadata, + source_tensors=source_tensors, + timeout_seconds=timeout_seconds, + ) + + stats = TransferStats( + bytes_received=transferred, + bytes_skipped=skipped, + tensors_received=len(source_tensors), + elapsed_seconds=elapsed, + discovery_seconds=0.0, # bypass-mode: no catalog roundtrip + path="from_metadata", + training_step=training_step, + source_worker_rank=None, # not derivable from raw metadata + ) + stats.update_bandwidth() + self.last_stats = stats + self.history.append(stats) + + logger.info( + f"RDMA transfer (direct metadata): {transferred / 1e9:.2f} GB, " + f"{len(source_tensors)} tensors, {elapsed:.2f}s, " + f"{stats.bandwidth_gbps:.1f} Gbps" + ) + + self._current_step = training_step + + for td in source_tensors: + if td.name in self._nixl._tensors: + yield td.name, self._nixl._tensors[td.name] + + def shutdown(self) -> None: + """Release NIXL agent and close gRPC channel.""" + if self._nixl is not None: + self._nixl.shutdown() + self._nixl = None + if self._client is not None: + self._client.close() + self._client = None + self._initialized = False + logger.info(f"MxRefitReceiver shut down: {self._agent_name}") diff --git a/modelexpress_client/python/modelexpress/shape_descriptors.py b/modelexpress_client/python/modelexpress/shape_descriptors.py new file mode 100644 index 00000000..d35366a4 --- /dev/null +++ b/modelexpress_client/python/modelexpress/shape_descriptors.py @@ -0,0 +1,374 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""DTensor placement → MX wire bridge for v2 NemoRL integration. + +Translates PyTorch's ``distributed.tensor.placement_types`` plus optional +MoE expert axis information into a small JSON payload we stash in +``SourceIdentity.extra_parameters[shape_registry]``. + +The v2 design (`pensieve/RL/NemoRL/04_design_v2_moe_rank_to_rank.md`) wants +an explicit, versioned ``ShapeRegistry`` proto on the MX server. To unblock +the prototype without touching the proto + Rust server code paths, this +module implements the registry as a JSON document attached to each +trainer's ``WorkerMetadata.extra_parameters``. Receivers consult it to +know each tensor's: + + - global shape (un-sharded) + - dtype + - placement (REPLICATE | SHARD axis | PARTIAL axis) + - shard range owned by *this* trainer's rank + - whether it's a MoE expert tensor + - which expert IDs this rank owns (when applicable) + +The format is intentionally JSON-shaped so the eventual proto migration +is a near-mechanical lift. See `pensieve/RL/NemoRL/05_mx_helpers_needed.md` +for the proto-shape we'd graduate to. +""" + +from __future__ import annotations + +import dataclasses +import json +from typing import Any + +import torch + +try: + from torch.distributed.tensor.placement_types import ( + Partial, + Replicate, + Shard, + ) + + _DTensor_AVAILABLE = True +except ImportError: # torch < 2.4 or non-distributed builds + Partial = Replicate = Shard = None # type: ignore[misc, assignment] + _DTensor_AVAILABLE = False + + +# Sentinel placement kinds. Match the eventual proto enum exactly. +PLACEMENT_REPLICATE = "REPLICATE" +PLACEMENT_SHARD = "SHARD" +PLACEMENT_PARTIAL = "PARTIAL" + +# Canonical compile targets. The string set is open — frameworks can introduce +# new targets without an MX bump — but receivers should treat targets they +# don't recognise as "do not consume". 
Always pair a non-HF target with +# ``compile_metadata`` describing the engine/kernel/quant choices that drive +# byte-level compatibility. +COMPILE_TARGET_HF_RAW = "hf_raw" +COMPILE_TARGET_VLLM_FUSED = "vllm_fused" +COMPILE_TARGET_DEEPGEMM_FP8 = "deep_gemm_fp8" +COMPILE_TARGET_CUTLASS_FP8 = "cutlass_fp8" +COMPILE_TARGET_TRTLLM = "trtllm" + + +@dataclasses.dataclass +class TensorDescriptorV2: + """Per-tensor shape + placement + expert + compile metadata. + + Fields: + name: tensor's qualified name in ``model.state_dict()``. + global_shape: shape of the *un-sharded* tensor across all DP/TP ranks. + dtype: torch dtype string (``"bfloat16"``, ``"float16"``, ...). + placement_kind: one of ``PLACEMENT_*``. + shard_axis: only meaningful if ``placement_kind == PLACEMENT_SHARD`` or + ``PLACEMENT_PARTIAL``. + local_shard_range: ``(start, end)`` along ``shard_axis`` owned by the + publisher's rank. ``None`` when ``REPLICATE``. + is_expert: whether this tensor's leading axis is the MoE expert axis. + expert_axis: index of the expert axis (only when ``is_expert``). + owned_expert_ids: expert IDs the publisher's rank owns. + compile_target: kernel layout label this tensor's bytes are encoded + for. ``"hf_raw"`` is the safe default — the trainer's HF + state-dict view, no post-processing. Other publishers may emit + ``"deep_gemm_fp8"``, ``"cutlass_fp8"``, ``"vllm_fused"``, etc. + Receivers filter on this via + :meth:`MxV2RefitReceiver.discover_v2_sources` so they only consume + sources whose layout they can decode. + compile_metadata: free-form key/value blob describing the specific + compile invocation (e.g. ``{"engine": "DeepGemm", "version": + "0.1.7", "block_size": 128, "scale_layout": "K-major"}``). + Receivers should treat a mismatch on any byte-affecting field as + a hard reject even if ``compile_target`` matches. + """ + + name: str + global_shape: tuple[int, ...] 
+ dtype: str + placement_kind: str = PLACEMENT_REPLICATE + shard_axis: int = 0 + local_shard_range: tuple[int, int] | None = None + is_expert: bool = False + expert_axis: int = 0 + owned_expert_ids: tuple[int, ...] = () + compile_target: str = COMPILE_TARGET_HF_RAW + compile_metadata: dict[str, Any] = dataclasses.field(default_factory=dict) + + def to_dict(self) -> dict[str, Any]: + d: dict[str, Any] = { + "name": self.name, + "global_shape": list(self.global_shape), + "dtype": self.dtype, + "placement_kind": self.placement_kind, + "shard_axis": self.shard_axis, + } + if self.local_shard_range is not None: + d["local_shard_range"] = list(self.local_shard_range) + if self.is_expert: + d["is_expert"] = True + d["expert_axis"] = self.expert_axis + d["owned_expert_ids"] = list(self.owned_expert_ids) + if self.compile_target != COMPILE_TARGET_HF_RAW: + d["compile_target"] = self.compile_target + if self.compile_metadata: + d["compile_metadata"] = dict(self.compile_metadata) + return d + + @classmethod + def from_dict(cls, d: dict[str, Any]) -> "TensorDescriptorV2": + rng = d.get("local_shard_range") + return cls( + name=d["name"], + global_shape=tuple(d["global_shape"]), + dtype=d["dtype"], + placement_kind=d.get("placement_kind", PLACEMENT_REPLICATE), + shard_axis=int(d.get("shard_axis", 0)), + local_shard_range=tuple(rng) if rng is not None else None, + is_expert=bool(d.get("is_expert", False)), + expert_axis=int(d.get("expert_axis", 0)), + owned_expert_ids=tuple(d.get("owned_expert_ids", [])), + compile_target=str(d.get("compile_target", COMPILE_TARGET_HF_RAW)), + compile_metadata=dict(d.get("compile_metadata", {})), + ) + + +def _dtype_to_str(dtype: torch.dtype) -> str: + s = str(dtype) + return s[len("torch.") :] if s.startswith("torch.") else s + + +def describe_tensor( + *, + name: str, + tensor: torch.Tensor, + rank: int, + fsdp_world_size: int, + is_expert: bool = False, + expert_axis: int = 0, + owned_expert_ids: tuple[int, ...] 
| set[int] | list[int] = (), + compile_target: str = COMPILE_TARGET_HF_RAW, + compile_metadata: dict[str, Any] | None = None, +) -> TensorDescriptorV2: + """Build a ``TensorDescriptorV2`` from a tensor + rank context. + + For a regular ``torch.Tensor`` (no ``placements`` attribute) this yields + a ``REPLICATE`` descriptor. For a DTensor it inspects ``tensor.placements`` + and emits the matching ``SHARD`` / ``PARTIAL`` descriptor; the local shard + range is computed assuming an even shard layout (every rank's local size + is the same). Uneven shards are not supported in v0; the caller should + pre-pad or fall back to bucket pack. + + The returned descriptor's ``global_shape`` is the shape of the full, + un-sharded tensor. ``local_shard_range`` is the half-open ``(start, end)`` + range along ``shard_axis`` that this rank owns. + """ + dtype_str = _dtype_to_str(tensor.dtype) + metadata = dict(compile_metadata) if compile_metadata else {} + placements = getattr(tensor, "placements", None) + if not _DTensor_AVAILABLE or not placements: + return TensorDescriptorV2( + name=name, + global_shape=tuple(int(s) for s in tensor.shape), + dtype=dtype_str, + placement_kind=PLACEMENT_REPLICATE, + is_expert=is_expert, + expert_axis=expert_axis, + owned_expert_ids=tuple(sorted(owned_expert_ids)), + compile_target=compile_target, + compile_metadata=metadata, + ) + + if len(placements) != 1: + # v0 supports only 1D meshes (FSDP only or TP only). HSDP / 2D meshes + # are deferred; see `04_design_v2_moe_rank_to_rank.md` §7. + raise NotImplementedError( + f"DTensor with {len(placements)}D mesh is not supported in v0; " + f"only 1D meshes are. 
tensor={name}" + ) + p = placements[0] + + if isinstance(p, Replicate): + return TensorDescriptorV2( + name=name, + global_shape=tuple(int(s) for s in tensor.shape), + dtype=dtype_str, + placement_kind=PLACEMENT_REPLICATE, + is_expert=is_expert, + expert_axis=expert_axis, + owned_expert_ids=tuple(sorted(owned_expert_ids)), + compile_target=compile_target, + compile_metadata=metadata, + ) + + if isinstance(p, Shard): + # ``tensor.shape`` semantics differ between real DTensors and plain + # tensors: + # * Real ``torch.distributed.tensor.DTensor.shape`` is the GLOBAL + # (un-sharded) shape; the local view is ``tensor.to_local().shape``. + # * A plain tensor (or a stand-in object with ``.placements`` but no + # ``.to_local``) has shape == local-view by construction. + # Compute global vs local accordingly. + try: + from torch.distributed.tensor import DTensor as _RealDTensor + except ImportError: # pragma: no cover — handled at module import + _RealDTensor = None # type: ignore[assignment] + + if _RealDTensor is not None and isinstance(tensor, _RealDTensor): + global_shape = tuple(int(s) for s in tensor.shape) + try: + local_extent = int(tensor.to_local().shape[p.dim]) + except Exception: + # Fallback: assume even sharding. 
+ local_extent = global_shape[p.dim] // fsdp_world_size + else: + local_shape = list(int(s) for s in tensor.shape) + global_shape_list = list(local_shape) + global_shape_list[p.dim] = local_shape[p.dim] * fsdp_world_size + global_shape = tuple(global_shape_list) + local_extent = local_shape[p.dim] + + start = rank * local_extent + end = start + local_extent + return TensorDescriptorV2( + name=name, + global_shape=global_shape, + dtype=dtype_str, + placement_kind=PLACEMENT_SHARD, + shard_axis=int(p.dim), + local_shard_range=(start, end), + is_expert=is_expert, + expert_axis=expert_axis, + owned_expert_ids=tuple(sorted(owned_expert_ids)), + compile_target=compile_target, + compile_metadata=metadata, + ) + + if isinstance(p, Partial): + return TensorDescriptorV2( + name=name, + global_shape=tuple(int(s) for s in tensor.shape), + dtype=dtype_str, + placement_kind=PLACEMENT_PARTIAL, + shard_axis=int(p.dim) if hasattr(p, "dim") else 0, + compile_target=compile_target, + compile_metadata=metadata, + ) + + raise NotImplementedError(f"unsupported DTensor placement: {p!r}") + + +def even_expert_owner_map( + *, + num_experts: int, + ep_world_size: int, +) -> dict[int, set[int]]: + """Default linear shard: rank N owns experts ``[N*chunk, (N+1)*chunk)``. + + Used to sanity-check that the expert layout the trainer publishes + matches the layout the inference rank expects to receive. 
+ """ + if ep_world_size <= 0: + raise ValueError("ep_world_size must be positive") + if num_experts % ep_world_size != 0: + raise ValueError( + f"num_experts={num_experts} not divisible by ep_world_size={ep_world_size}; " + f"uneven expert assignment requires explicit owner map" + ) + chunk = num_experts // ep_world_size + return {r: set(range(r * chunk, (r + 1) * chunk)) for r in range(ep_world_size)} + + +def encode_registry( + descriptors: list[TensorDescriptorV2], + *, + version: int, + trainer_world_layout: str, +) -> str: + """Serialize a registry to a string for ``extra_parameters``.""" + payload = { + "version": int(version), + "trainer_world_layout": trainer_world_layout, + "tensors": [d.to_dict() for d in descriptors], + } + return json.dumps(payload, separators=(",", ":")) + + +def decode_registry(blob: str) -> dict[str, Any]: + """Inverse of ``encode_registry``. Returns ``{version, trainer_world_layout, tensors}``.""" + parsed = json.loads(blob) + parsed["tensors"] = [ + TensorDescriptorV2.from_dict(t) for t in parsed.get("tensors", []) + ] + return parsed + + +def encode_expert_set(expert_ids: set[int] | list[int] | tuple[int, ...]) -> str: + """Compact encoding for an expert id set, used in ``extra_parameters``.""" + return ",".join(str(int(e)) for e in sorted(set(expert_ids))) + + +def decode_expert_set(s: str | None) -> set[int]: + if not s: + return set() + return {int(p) for p in s.split(",") if p.strip()} + + +def compile_target_matches( + descriptor: TensorDescriptorV2, + *, + allowed_targets: set[str] | frozenset[str] | None, + required_metadata: dict[str, Any] | None = None, +) -> bool: + """Return True if ``descriptor`` is acceptable to a receiver. + + Args: + descriptor: the publisher-side descriptor (its ``compile_target`` and + ``compile_metadata`` describe how the bytes are laid out). + allowed_targets: receiver-side whitelist of compile-target strings the + receiver knows how to consume. 
``None`` means "accept everything" + (back-compat shim — equivalent to the v0 behaviour). + required_metadata: optional key/value subset the descriptor's + ``compile_metadata`` must agree with byte-for-byte. Useful for + pinning e.g. ``{"block_size": 128, "scale_layout": "K-major"}`` + so a Cutlass receiver doesn't accept a DeepGemm-block-256 + publisher's bytes by mistake. + """ + if allowed_targets is not None and descriptor.compile_target not in allowed_targets: + return False + if required_metadata: + for key, want in required_metadata.items(): + if descriptor.compile_metadata.get(key) != want: + return False + return True + + +__all__ = [ + "COMPILE_TARGET_CUTLASS_FP8", + "COMPILE_TARGET_DEEPGEMM_FP8", + "COMPILE_TARGET_HF_RAW", + "COMPILE_TARGET_TRTLLM", + "COMPILE_TARGET_VLLM_FUSED", + "PLACEMENT_PARTIAL", + "PLACEMENT_REPLICATE", + "PLACEMENT_SHARD", + "TensorDescriptorV2", + "compile_target_matches", + "decode_expert_set", + "decode_registry", + "describe_tensor", + "encode_expert_set", + "encode_registry", + "even_expert_owner_map", +] diff --git a/modelexpress_client/python/modelexpress/training_publisher.py b/modelexpress_client/python/modelexpress/training_publisher.py new file mode 100644 index 00000000..2848ee6d --- /dev/null +++ b/modelexpress_client/python/modelexpress/training_publisher.py @@ -0,0 +1,319 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Training-side weight publisher for RL refit via ModelExpress. + +Wraps NixlTransferManager + MxClient to register updated model weights +on the training GPU and publish metadata to the MX Server so that +inference workers can discover and pull them via RDMA. 
+ +Typical usage in an RL training loop:: + + publisher = MxTrainingPublisher("trainer-0", device_id=0, mx_server_url="mx-server:8001") + publisher.initialize(model_name="Qwen/Qwen2.5-1.5B") + + # After optimizer.step(): + for layer_idx, layer_sd in enumerate_layers(model): + publisher.publish_layer(layer_sd, layer_idx, step=training_step) + publisher.mark_ready() +""" + +from __future__ import annotations + +import logging +import uuid +from typing import Iterator + +import torch + +from .client import MxClient +from .nixl_transfer import NixlTransferManager, is_nixl_available +from .types import TensorDescriptor +from . import p2p_pb2 + +logger = logging.getLogger("modelexpress.training_publisher") + + +class MxTrainingPublisher: + """Publishes updated model weights from a training process to ModelExpress. + + One instance per GPU rank. On the training side, after each optimizer step, + the publisher registers weight tensors with NIXL and publishes metadata to + the MX Server. Inference workers discover the source via ``ListSources`` + and pull weights via RDMA. + + Args: + agent_name: Unique NIXL agent name (e.g. ``"trainer-rank-0"``). + device_id: CUDA device index for this training rank. + mx_server_url: gRPC address of the ModelExpress server. + listen_port: Optional NIXL listen port for P2P metadata exchange. + """ + + def __init__( + self, + agent_name: str, + device_id: int, + mx_server_url: str = "localhost:8001", + listen_port: int | None = None, + ): + self._agent_name = agent_name + self._device_id = device_id + self._mx_server_url = mx_server_url + self._listen_port = listen_port + + self._nixl: NixlTransferManager | None = None + self._client: MxClient | None = None + self._worker_id: str = str(uuid.uuid4()) + self._mx_source_id: str | None = None + self._model_name: str = "" + self._training_framework: str = "unknown" + self._initialized = False + self._registered = False + # Tracks which publish path has been used. 
publish_weights and + # publish_layer are mutually exclusive within a publisher's lifetime + # because they hold different sets of tensors registered with NIXL — + # mixing them silently invalidates the cached registration. None + # until first publish; "weights" or "layer" thereafter. + self._publish_mode: str | None = None + + @property + def mx_source_id(self) -> str | None: + return self._mx_source_id + + @property + def worker_id(self) -> str: + return self._worker_id + + def initialize( + self, + model_name: str, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + expert_parallel_size: int = 1, + dtype: str = "bfloat16", + training_framework: str = "unknown", + ) -> None: + """Initialize NIXL agent and MX client. + + Must be called before any publish operations. Sets up the source + identity that inference workers will use to filter compatible sources. + + Args: + training_framework: Identifier for the framework driving this + publisher (``"prime_rl"``, ``"verl"``, ``"nemo_rl"``, ...). + Surfaced in ``SourceIdentity.extra_parameters`` so consumers + can disambiguate sources from different frameworks publishing + to the same MX Server. Default ``"unknown"`` is intentional — + callers should pass an explicit value. + """ + if not is_nixl_available(): + raise RuntimeError( + "NIXL is not available. Install nixl or build from source." 
+ ) + + self._model_name = model_name + self._training_framework = training_framework + self._identity_kwargs = dict( + model_name=model_name, + mx_source_type=p2p_pb2.MX_SOURCE_TYPE_WEIGHTS, + backend_framework=p2p_pb2.BACKEND_FRAMEWORK_UNKNOWN, + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + expert_parallel_size=expert_parallel_size, + dtype=dtype, + ) + + self._nixl = NixlTransferManager( + agent_name=self._agent_name, + device_id=self._device_id, + listen_port=self._listen_port, + ) + self._nixl.initialize() + + self._client = MxClient(server_url=self._mx_server_url) + self._initialized = True + logger.info( + f"MxTrainingPublisher initialized: agent={self._agent_name}, " + f"device={self._device_id}, model={model_name}, " + f"framework={training_framework}" + ) + + def _build_identity(self, step: int) -> p2p_pb2.SourceIdentity: + """Build a SourceIdentity proto with the current training step.""" + return p2p_pb2.SourceIdentity( + extra_parameters={ + "training_step": str(step), + "training_framework": self._training_framework, + }, + **self._identity_kwargs, + ) + + def _build_tensor_protos( + self, descriptors: list[TensorDescriptor] + ) -> list[p2p_pb2.TensorDescriptor]: + return [ + p2p_pb2.TensorDescriptor( + name=d.name, + addr=d.addr, + size=d.size, + device_id=d.device_id, + dtype=d.dtype, + ) + for d in descriptors + ] + + def publish_weights( + self, + named_tensors: dict[str, torch.Tensor], + step: int, + worker_rank: int = 0, + ) -> str: + """Register tensors with NIXL and publish metadata to MX Server. + + This is the all-at-once variant. For layer-by-layer streaming, + use :meth:`publish_layer` instead. + + NIXL memory regions are registered only on the first call since + parameter tensor addresses stay constant across optimizer steps. + Subsequent calls reuse the cached metadata and descriptors. + + Args: + named_tensors: Mapping of parameter name to GPU tensor. 
+ step: Current training step (used for version tracking). + worker_rank: GPU rank of this worker within the training group. + + Returns: + The ``mx_source_id`` (16-char hex) assigned by the server. + """ + if not self._initialized: + raise RuntimeError("Call initialize() before publish_weights()") + if self._publish_mode == "layer": + raise RuntimeError( + "publish_weights() and publish_layer() are mutually exclusive: " + "this publisher has already been used in 'layer' mode " + "(publish_layer was called previously). Mixing the two paths " + "leaves NIXL holding only the most recently registered tensor " + "set, which silently invalidates earlier publishes. Use one " + "mode per publisher lifetime." + ) + self._publish_mode = "weights" + + if not self._registered: + self._nixl.register_tensors(named_tensors) + self._registered = True + logger.info( + f"Registered {len(named_tensors)} tensors with NIXL " + f"(metadata={len(self._nixl.nixl_metadata)} bytes)" + ) + metadata = self._nixl.nixl_metadata + descriptors = self._nixl.tensor_descriptors + + identity = self._build_identity(step) + worker_meta = p2p_pb2.WorkerMetadata( + worker_rank=worker_rank, + nixl_metadata=metadata, + tensors=self._build_tensor_protos(descriptors), + status=p2p_pb2.SOURCE_STATUS_INITIALIZING, + agent_name=self._agent_name, + ) + + self._mx_source_id = self._client.publish_metadata( + identity=identity, + worker=worker_meta, + worker_id=self._worker_id, + ) + logger.info( + f"Published {len(named_tensors)} tensors for step {step} " + f"(mx_source_id={self._mx_source_id})" + ) + return self._mx_source_id + + def publish_layer( + self, + layer_state_dict: dict[str, torch.Tensor], + layer_idx: int, + step: int, + worker_rank: int = 0, + ) -> str: + """Publish a single layer's weights to MX Server. + + Designed for PRIME-RL's layer-by-layer streaming pattern where + ``filter_state_dict_by_layers()`` yields one layer at a time. 
+ + Layer tensors are registered with NIXL (overwriting previous + registration), and metadata is published to the MX Server. The + inference side accumulates all layers before loading. + + Args: + layer_state_dict: Parameter name -> tensor for this layer. + layer_idx: Layer index (-1 for non-layer weights like embeddings). + step: Current training step. + worker_rank: GPU rank of this worker. + + Returns: + The ``mx_source_id`` assigned by the server. + """ + if not self._initialized: + raise RuntimeError("Call initialize() before publish_layer()") + if self._publish_mode == "weights": + raise RuntimeError( + "publish_layer() and publish_weights() are mutually exclusive: " + "this publisher has already been used in 'weights' mode " + "(publish_weights was called previously). See publish_weights " + "for the full explanation." + ) + self._publish_mode = "layer" + + self._nixl.register_tensors(layer_state_dict) + metadata = self._nixl.nixl_metadata + descriptors = self._nixl.tensor_descriptors + + identity = self._build_identity(step) + identity.extra_parameters["layer_idx"] = str(layer_idx) + + worker_meta = p2p_pb2.WorkerMetadata( + worker_rank=worker_rank, + nixl_metadata=metadata, + tensors=self._build_tensor_protos(descriptors), + status=p2p_pb2.SOURCE_STATUS_INITIALIZING, + agent_name=self._agent_name, + ) + + self._mx_source_id = self._client.publish_metadata( + identity=identity, + worker=worker_meta, + worker_id=self._worker_id, + ) + logger.debug( + f"Published layer {layer_idx} ({len(layer_state_dict)} tensors) " + f"for step {step}" + ) + return self._mx_source_id + + def mark_ready(self, worker_rank: int = 0) -> bool: + """Signal that all layers/weights have been published and are ready. + + Inference workers filter on ``SOURCE_STATUS_READY`` when polling, + so this must be called after all publish calls for a given step. 
+ """ + if self._mx_source_id is None: + raise RuntimeError( + "No weights published yet; call publish_weights() or publish_layer() first" + ) + + return self._client.update_status( + mx_source_id=self._mx_source_id, + worker_id=self._worker_id, + worker_rank=worker_rank, + status=p2p_pb2.SOURCE_STATUS_READY, + ) + + def shutdown(self) -> None: + """Release NIXL agent and close gRPC channel.""" + if self._nixl is not None: + self._nixl.shutdown() + self._nixl = None + if self._client is not None: + self._client.close() + self._client = None + self._initialized = False + logger.info(f"MxTrainingPublisher shut down: {self._agent_name}") diff --git a/modelexpress_client/python/modelexpress/vllm_weight_transfer.py b/modelexpress_client/python/modelexpress/vllm_weight_transfer.py new file mode 100644 index 00000000..aa6e29d7 --- /dev/null +++ b/modelexpress_client/python/modelexpress/vllm_weight_transfer.py @@ -0,0 +1,511 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""vLLM ``WeightTransferEngine`` adapter for ModelExpress + NIXL. + +This module is the **upstream-facing form** of all the Phase 2 / 3 / 4 +work landed across this branch and the prime-rl follow-up PRs. 
It +wraps the v2 fat clients (:class:`MxV2RefitReceiver` + +:class:`MxV2TrainingPublisher`) behind vLLM's native ``WeightTransferEngine`` +abstract base (introduced in the 2026-05-28 vLLM Native RL APIs blog), +so RL frameworks can pick it up via the standard four-phase lifecycle: + +:: + + from vllm import LLM + from vllm.config import WeightTransferConfig + import modelexpress.vllm_weight_transfer # noqa: F401 — registers "mx_nixl" + + llm = LLM(model="...", weight_transfer_config=WeightTransferConfig(backend="mx_nixl")) + llm.init_weight_transfer_engine(WeightTransferInitRequest( + init_info=MxInitInfo( + mx_server_url="modelexpress-server:8001", + model_name="Qwen/Qwen3-30B-A3B-Instruct-2507", + worker_rank=0, agent_name="vllm-inference-r0", device_id=0, + ) + )) + + # per training step: + llm.start_weight_update() + llm.update_weights(WeightTransferUpdateRequest( + update_info=MxUpdateInfo( + version=step, + compile_target_filter={"cutlass_fp8"}, + target_tp_layout=None, # matched-TP fast path; set for mixed-TP + ) + )) + llm.finish_weight_update() + +What's wrapped from each phase: + +* Phase 2 — heartbeat + freshest-per-rank dedup + same-rank-only filter + in the discovery layer +* Phase 3a — ``compile_target`` + ``compile_metadata`` tagging on the + publisher side (carried in v2 ``TensorDescriptorV2``) +* Phase 3b — ``compile_target_filter`` + ``required_compile_metadata`` + on the receiver side (refuses incompatible bytes before RDMA) +* Phase 4 — multi-source slice picker + stitching for mixed-TP / + mixed-EP between trainer and inference + +The engine handles two receive paths: + +1. **Matched TP/EP** (the common case today): single-source same-rank + pull via ``MxV2RefitReceiver.receive_from``. ``MxUpdateInfo.target_tp_layout`` + is ``None``. +2. **Mixed-TP** (e.g. trainer FSDP=4 ↔ inference TP=8): the multi-source + plan is computed via ``discover_v2_sources_for_slice`` and stitched + in ``receive_via_plan``. 
Caller sets ``MxUpdateInfo.target_tp_layout``. + +Design rationale, comparison to vLLM's built-in NCCL / IPC backends, and +comparison to Anyscale's RDT (PR #43375) plugin are in +``pensieve/RL/PrimeRL/10_mx_weight_transfer_engine_design.md``. +""" + +from __future__ import annotations + +import logging +import os +from dataclasses import dataclass, field +from typing import Any, Callable, Iterator + +import torch +from torch import Tensor + +from .nemo_rl_v2 import ( + MxV2RefitReceiver, + MxV2TrainingPublisher, + TargetTPLayout, + TrainerWorldLayout, +) +from .shape_descriptors import COMPILE_TARGET_HF_RAW + +logger = logging.getLogger("modelexpress.vllm_weight_transfer") + + +# ---------------------------------------------------------------------------- +# Init / Update / TrainerSend info dataclasses. +# +# We do NOT subclass vLLM's ``WeightTransferInitInfo`` / ``WeightTransferUpdateInfo`` +# at module import time, because we want the adapter to be import-safe +# in environments where vLLM is not installed (tests, the publisher +# side, benchmark harnesses). We expose plain dataclasses and let the +# registration step (at the bottom of this file) do the subclass +# substitution against vLLM's bases if available. +# ---------------------------------------------------------------------------- + + +@dataclass +class MxInitInfo: + """Initialization data for the MX backend. + + Args: + mx_server_url: gRPC URL of the ModelExpress server (e.g. + ``"modelexpress-server.kavin.svc.cluster.local:8001"``). + model_name: model identifier shared between trainer and inference. + Receivers filter discovery by this exact string. + worker_rank: this receiver's rank index. With ``MxV2RefitReceiver``'s + same-rank-only default, the receiver pulls from the trainer + rank with the same index. + agent_name: NIXL agent name for this receiver. Conventionally + ``f"vllm-inference-r{worker_rank}"``. + device_id: CUDA device this receiver writes into. 
+ listen_port: optional NIXL listen port; ``None`` = auto-pick. + publish_self_as_replica: if True, after a successful receive + this engine calls ``publish_self_as_source`` so subsequent + receivers can pull from this rank instead of the trainer + (TensorHub-style pipeline replication / tree fan-out). + Recommended ``True`` for elastic deployments. + """ + + mx_server_url: str + model_name: str + worker_rank: int + agent_name: str + device_id: int = 0 + listen_port: int | None = None + publish_self_as_replica: bool = True + + +@dataclass +class MxUpdateInfo: + """Per-refit update data. + + Args: + version: monotonic step counter the trainer is at. Receiver pulls + sources with ``training_step >= version``. + target_tp_layout: receiver's local TP/EP slice descriptor (Phase 4). + ``None`` (default) → matched-TP fast path: single-source + same-rank pull via ``discover_v2_sources`` + ``receive_from``. + Set to a ``TargetTPLayout`` when trainer and inference have + different TP/EP layouts; engine then uses + ``discover_v2_sources_for_slice`` + ``receive_via_plan``. + compile_target_filter: optional whitelist of compile-target strings. + ``None`` = accept anything (back-compat). Set to e.g. + ``{"cutlass_fp8"}`` or ``{"cutlass_fp8", "hf_raw"}`` to refuse + sources whose tensors are tagged with a target outside this set. + required_compile_metadata: optional kv subset that every tensor's + ``compile_metadata`` must agree with. Useful for pinning block + sizes, scale layouts, kernel versions. + timeout_seconds: cap on per-receive RDMA wait. + same_rank_only: enforce same-rank trainer-to-inference peering + (required on GCP GB200 multi-NIC fabrics where rdma-0..3 + are separate L3 subnets). + dedup_freshest_per_rank: keep only the freshest published source + per ``worker_rank`` (the bug class our Phase 2 PR codifies). 
+ """ + + version: int + target_tp_layout: TargetTPLayout | None = None + compile_target_filter: set[str] | frozenset[str] | None = None + required_compile_metadata: dict[str, Any] | None = None + timeout_seconds: float = 300.0 + same_rank_only: bool = True + dedup_freshest_per_rank: bool = True + + +@dataclass +class MxTrainerSendArgs: + """Optional trainer-side args for :meth:`MxWeightTransferEngine.trainer_send_weights`. + + Trainers that drive sends through this engine pass these args at each + publish to control how the bytes are tagged. The publisher itself is + long-lived and should be reused across steps; only ``version`` and + optionally ``compile_target`` / ``compile_metadata`` change per step. + """ + + publisher: MxV2TrainingPublisher # long-lived, heartbeat-started + version: int # the training step + compile_target: str = COMPILE_TARGET_HF_RAW + compile_metadata: dict[str, Any] | None = None + # Per-tensor MoE expert metadata. Map tensor name → expert axis and + # tuple of expert IDs this rank owns. Leave empty for non-expert tensors. + expert_axis_map: dict[str, int] = field(default_factory=dict) + owned_expert_ids: dict[str, tuple[int, ...]] = field(default_factory=dict) + + +# ---------------------------------------------------------------------------- +# The engine itself. +# +# We hold off on subclassing vLLM's ``WeightTransferEngine`` until the +# registration step (see bottom of file), so this module imports cleanly +# without vLLM. Tests can exercise the methods directly on this class. +# ---------------------------------------------------------------------------- + + +class MxWeightTransferEngine: + """ModelExpress + NIXL adapter for vLLM's WeightTransferEngine API. 
+ + Receiver side wraps :class:`MxV2RefitReceiver` and exposes the four + capabilities prime-rl currently bolts on by hand: + + - heartbeat-aware rendezvous (Phase 2) + - ``compile_target`` / ``compile_metadata`` filtering (Phase 3a/3b) + - multi-source slice picker for mixed-TP (Phase 4) + - tree fan-out so newcomers pull from already-loaded peers, not + the trainer (TensorHub pipeline pattern, opt-in via + :attr:`MxInitInfo.publish_self_as_replica`) + + Trainer side wraps :class:`MxV2TrainingPublisher` via the optional + :meth:`trainer_send_weights` classmethod. Trainers can drive + publishes outside this engine — the method is a convenience for the + case where the trainer wants the engine to own the publish. + + Args: + init_info: optional pre-built init info, allowing the engine to + be constructed and initialized in one go. If omitted, the + caller must invoke :meth:`init_transfer_engine` before any + :meth:`receive_weights` call. + """ + + # vLLM's WeightTransferEngine declares these class attributes + # pointing at the request-info dataclasses. We populate them so the + # factory can find our types post-registration. + init_info_cls = MxInitInfo + update_info_cls = MxUpdateInfo + + def __init__(self, init_info: MxInitInfo | None = None) -> None: + self._receiver: MxV2RefitReceiver | None = None + self._init_info: MxInitInfo | None = None + if init_info is not None: + self.init_transfer_engine(init_info) + + # ------------------------------------------------------------------ + # Receiver-side API (the WeightTransferEngine contract). + # ------------------------------------------------------------------ + + def init_transfer_engine(self, init_info: MxInitInfo) -> None: + """Stand up the MX v2 receiver. + + NIXL register doesn't happen here — vLLM gives us the model's + param buffers only after the engine is initialized, and the + receiver's ``initialize()`` is what binds them. 
We defer to the + first :meth:`receive_weights` call so the engine can be + instantiated and configured before vLLM has built its workers. + """ + self._init_info = init_info + self._receiver = MxV2RefitReceiver( + agent_name=init_info.agent_name, + device_id=init_info.device_id, + mx_server_url=init_info.mx_server_url, + worker_rank=init_info.worker_rank, + listen_port=init_info.listen_port, + ) + logger.info( + "MxWeightTransferEngine init: agent=%s worker_rank=%s server=%s", + init_info.agent_name, + init_info.worker_rank, + init_info.mx_server_url, + ) + + def receive_weights( + self, + update_info: MxUpdateInfo, + load_weights: Callable[[list[tuple[str, Tensor]]], None], + ) -> None: + """Pull weights via NIXL RDMA + feed them through vLLM's load_weights. + + Mode selection happens on ``update_info.target_tp_layout``: + + - ``None`` → matched-TP fast path. Calls ``discover_v2_sources`` + + ``receive_from`` (single source, same-rank). Applies the + Phase 3 filters at discovery time so incompatible sources + are rejected before any RDMA cycles are spent. + - non-``None`` → mixed-TP / Phase-4 path. Calls + ``discover_v2_sources_for_slice`` to build the + ``SliceCoveragePlan``, then ``receive_via_plan`` to stitch + the slice from N publisher ranks. Applies the same Phase 3 + filters at discovery time. + + Args: + update_info: the per-step request descriptor. See + :class:`MxUpdateInfo` for fields. + load_weights: vLLM-provided callback. Each yielded tensor + is fed in as a single-element list so vLLM's + ``stacked_params_mapping`` can handle HF→fused name + remapping per call (matching the convention used by + NCCL / IPC / RDT backends). + """ + if self._receiver is None or self._init_info is None: + raise RuntimeError( + "MxWeightTransferEngine.init_transfer_engine() must be called first" + ) + + # Lazy initialize: the v2 receiver needs to be initialize()'d + # exactly once. 
We pass an empty model_tensors map because the + # current v0 of this adapter uses the scratch-buffer path (it + # writes into receiver-allocated buffers, then yields them for + # vLLM's load_weights to consume — matching RDT's pattern). + # When the upstream API gets a register_destinations hook + # (proposed extension, see design doc §5.1), this is where we + # pre-register vLLM's named_parameters for zero-copy receive. + if not self._receiver._initialized: + self._receiver.initialize(model_tensors={}) + + # ----- Phase 4 path: mixed-TP / multi-source ----- + if update_info.target_tp_layout is not None: + plan = self._receiver.discover_v2_sources_for_slice( + model_name=self._init_info.model_name, + target_layout=update_info.target_tp_layout, + min_version=update_info.version, + same_rank_only=update_info.same_rank_only, + compile_target_filter=update_info.compile_target_filter, + required_compile_metadata=update_info.required_compile_metadata, + ) + if not plan.fully_covered: + raise RuntimeError( + f"MxWeightTransferEngine: no covering source set for " + f"version={update_info.version}; missing={plan.missing}" + ) + for name, tensor in self._receiver.receive_via_plan( + plan, timeout_seconds=update_info.timeout_seconds + ): + load_weights([(name, tensor)]) + else: + # ----- Fast path: matched-TP / single-source ----- + candidates = self._receiver.discover_v2_sources( + model_name=self._init_info.model_name, + min_version=update_info.version, + same_rank_only=update_info.same_rank_only, + compile_target_filter=update_info.compile_target_filter, + required_compile_metadata=update_info.required_compile_metadata, + ) + chosen = self._receiver.pick_best_source(candidates) + if chosen is None: + raise RuntimeError( + f"MxWeightTransferEngine: no source matches filters for " + f"version={update_info.version}; " + f"compile_target_filter={update_info.compile_target_filter}, " + f"required_compile_metadata={update_info.required_compile_metadata}" + ) + # Scratch path: 
receiver allocates buffers matching the + # publisher's layout, NIXL writes into them, we yield them + # for the load_weights callback. This matches Anyscale's + # RDT plugin pattern and works without pre-registered + # model parameters — the common cold-start case for vLLM + # and the only sensible mode for the benchmark harness. + # Once vLLM exposes register_destinations, this can switch + # to the zero-copy `receive_from` path (design doc §5.1). + for name, tensor in self._receiver.receive_from_scratch( + chosen, timeout_seconds=update_info.timeout_seconds + ): + load_weights([(name, tensor)]) + + # Tree fan-out / pipeline replication: after a successful + # receive, optionally publish this rank's buffers so subsequent + # receivers (newcomers in an elastic deployment) can pull from + # us instead of the trainer. + if self._init_info.publish_self_as_replica: + try: + self._receiver.publish_self_as_source( + version=update_info.version, + model_name=self._init_info.model_name, + ) + except Exception as e: # noqa: BLE001 + # Pipeline replication is best-effort — it's an + # optimization for elastic deployments, not a + # correctness requirement. + logger.warning( + "MxWeightTransferEngine: publish_self_as_source failed: %s; " + "tree fan-out disabled for this cycle", + e, + ) + + # ------------------------------------------------------------------ + # Optional trainer-side API. + # ------------------------------------------------------------------ + + @classmethod + def trainer_send_weights( + cls, + iterator: Iterator[tuple[str, Tensor]], + trainer_args: MxTrainerSendArgs, + ) -> str: + """Publish all tensors yielded by ``iterator`` as one v2 publish. + + The trainer typically calls this once per training step. The + ``compile_target`` and ``compile_metadata`` on + :class:`MxTrainerSendArgs` propagate into every tensor's + :class:`TensorDescriptorV2`, which is the wire form that + receivers filter on via :attr:`MxUpdateInfo.compile_target_filter`. 
+ + Args: + iterator: ``(name, tensor)`` pairs to publish. For DTensors, + each tensor MUST be the rank-local shard + (``.to_local()``) — never the gathered full tensor. + That's the whole point of the v2 rank-to-rank design. + trainer_args: see :class:`MxTrainerSendArgs`. + + Returns: + the ``mx_source_id`` (16-hex hash) assigned by the server. + """ + pub = trainer_args.publisher + for name, tensor in iterator: + is_expert = name in trainer_args.expert_axis_map + pub.add_tensor( + name=name, + tensor=tensor, + is_expert=is_expert, + expert_axis=trainer_args.expert_axis_map.get(name, 0), + owned_expert_ids=trainer_args.owned_expert_ids.get(name, ()), + compile_target=trainer_args.compile_target, + compile_metadata=trainer_args.compile_metadata, + ) + return pub.publish(version=trainer_args.version) + + # ------------------------------------------------------------------ + # Metrics surface — what benchmarks + dashboards read. + # ------------------------------------------------------------------ + + @property + def last_transfer_stats(self): + """The :class:`TransferStats` from the most recent + :meth:`receive_weights` call. ``None`` before the first call. + + For the Phase-4 multi-source path, this reflects the LAST + contributing source's stats; the full per-source history is + on ``self.transfer_history``. + """ + if self._receiver is None: + return None + return self._receiver._receiver.last_stats # MxV2RefitReceiver → MxRefitReceiver + + @property + def transfer_history(self): + """Per-call :class:`TransferStats` history across the engine's + lifetime. Each item corresponds to one underlying NIXL + ``receive_from_source`` invocation.""" + if self._receiver is None: + return [] + return self._receiver._receiver.history + + @property + def last_discovery_seconds(self) -> float: + """Wall time of the most recent ``discover_v2_sources`` call — + the control-plane round-trip latency. 
Distinct from data-plane + RDMA time.""" + if self._receiver is None: + return 0.0 + return self._receiver._last_discovery_seconds + + +# ---------------------------------------------------------------------------- +# Registration with vLLM's WeightTransferEngineFactory. +# +# We import vLLM lazily and try/except: if vLLM is not installed in this +# environment (publisher side, tests, harnesses), the module still loads +# and the class is usable directly — only the factory registration is +# skipped. The MX_WEIGHT_TRANSFER_AUTOREGISTER env var lets callers opt +# out of auto-registration even when vLLM IS installed (useful for +# environments where vLLM is present but the user wants a different +# backend name or doesn't want the side-effect). +# ---------------------------------------------------------------------------- + + +def _register_with_vllm() -> bool: + """Register ``MxWeightTransferEngine`` with vLLM's factory. + + Returns True on successful registration, False if vLLM isn't + available or the user opted out via env var. + """ + if os.environ.get("MX_WEIGHT_TRANSFER_AUTOREGISTER") == "0": + return False + + try: + from vllm.distributed.weight_transfer import WeightTransferEngineFactory + from vllm.distributed.weight_transfer.base import ( + WeightTransferEngine as _VllmWeightTransferEngineBase, + ) + except ImportError: + logger.debug( + "vLLM not available; MxWeightTransferEngine remains usable " + "directly but is not registered with WeightTransferEngineFactory" + ) + return False + + # Subclass to bind to vLLM's actual base. Without this, vLLM's + # factory may reject our engine as not-a-WeightTransferEngine. 
+ class _MxEngineForVllm(_VllmWeightTransferEngineBase, MxWeightTransferEngine): + pass + + try: + WeightTransferEngineFactory.register_engine("mx_nixl", _MxEngineForVllm) + logger.info("Registered MxWeightTransferEngine as backend='mx_nixl'") + return True + except Exception as e: # noqa: BLE001 + logger.warning( + "Failed to register MxWeightTransferEngine with vLLM: %s", e + ) + return False + + +# Auto-register on import. Callers who don't want this can set +# MX_WEIGHT_TRANSFER_AUTOREGISTER=0 before importing this module. +_AUTOREGISTERED = _register_with_vllm() + + +__all__ = [ + "MxInitInfo", + "MxUpdateInfo", + "MxTrainerSendArgs", + "MxWeightTransferEngine", +] diff --git a/modelexpress_client/python/scripts/v2_dtensor_e2e_demo.py b/modelexpress_client/python/scripts/v2_dtensor_e2e_demo.py new file mode 100644 index 00000000..bd4ca9a7 --- /dev/null +++ b/modelexpress_client/python/scripts/v2_dtensor_e2e_demo.py @@ -0,0 +1,307 @@ +#!/usr/bin/env python3 +"""End-to-end NIXL+MX v2 demo with REAL DTensors. + +Sister to v2_moe_e2e_demo.py, but exercises the codepath that the NemoRL +DTensorPolicyWorker.stream_weights_via_mx uses in production: real DTensors +on a torch.distributed mesh, ``tensor.to_local()`` on the publisher, +``MxV2TrainingPublisher.add_tensor`` overriding ``global_shape`` from the +DTensor view, registry round-trip via the synthetic ``__mx_v2_meta__`` +sidecar, and same-rank RDMA pulls. + +What this validates that v2_moe_e2e_demo.py does NOT: + * ``shape_descriptors.describe_tensor`` works on a real DTensor (not a + fake stand-in object). Shard axis, local shard range, and the global + shape inferred from ``tensor.shape × fsdp_world_size`` line up with + the DTensor's actual placement. + * The publisher's per-tensor ``global_shape`` override (set after + ``add_tensor``) survives the JSON round-trip and is observable by the + receiver via ``decode_registry``. 
+ * ``tensor.to_local()`` is in fact what gets NIXL-registered (no allgather + happens). We assert the publisher's NIXL-registered tensor has the + SHARD dim equal to ``global_dim / fsdp_world_size``. + +Run inside any pod that has GPUs, NIXL, reachability to MX server, and +torch.distributed available. + + WORLD_SIZE=4 N_REFIT_CYCLES=2 python3 v2_dtensor_e2e_demo.py +""" +from __future__ import annotations + +import logging +import os +import sys +import time + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp +from torch.distributed.device_mesh import init_device_mesh +from torch.distributed.tensor import DTensor, distribute_tensor +from torch.distributed.tensor.placement_types import Shard + +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s") +log = logging.getLogger("v2-dt-demo") + +MX_URL = os.environ.get("MX_URL", "modelexpress-server.kavin.svc.cluster.local:8001") +MODEL_NAME = os.environ.get("MODEL_NAME", "v2-dtensor-demo/Qwen3MoE-stub") +WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "4")) +N_REFIT_CYCLES = int(os.environ.get("N_REFIT_CYCLES", "2")) + +# Tensor sizes: produce a recognizably-sharded weight per rank. +HIDDEN = int(os.environ.get("HIDDEN", "1024")) +INTER = int(os.environ.get("INTER", "2048")) + +# FORCE_RDMA=1 disables UCX's intra-node `cuda_ipc` fast path so the demo +# exercises the same `rc_mlx5` (or `cuda_copy` over RDMA NIC) descriptor-list +# validation path that real cross-node transfers do. Without this, intra-node +# loopback runs through `cuda_ipc` which silently tolerates malformed +# descriptor entries — e.g. the v2 `__mx_v2_meta__` sidecar (addr=0, size=0) +# bug that MX PR #295 (commit 53c69ec) fixed. Set FORCE_RDMA=1 on every +# pre-deploy run so cross-host descriptor-list bugs surface in loopback. 
+if os.environ.get("FORCE_RDMA") == "1": + os.environ["UCX_TLS"] = os.environ.get("UCX_TLS", "self,sm,rc_mlx5,cuda_copy,tcp") + log.info( + "FORCE_RDMA=1: UCX_TLS=%s (cuda_ipc disabled to exercise descriptor validation)", + os.environ["UCX_TLS"], + ) + + +def _setup_dist(rank: int, world_size: int) -> None: + os.environ.setdefault("MASTER_ADDR", "127.0.0.1") + os.environ.setdefault("MASTER_PORT", "29551") + dist.init_process_group( + backend="nccl", + rank=rank, + world_size=world_size, + device_id=torch.device(f"cuda:{rank}"), + ) + torch.cuda.set_device(rank) + + +def _fingerprint(rank: int, version: int) -> int: + """bf16-exact sentinel encoding (multiples of 8 are exact for |v| <= 2**11).""" + return (rank + 1) * 8 + version * 64 + + +def trainer_publish_dt(rank: int, version: int, layout, mx_url: str, mesh): + """Publish a sharded DTensor via the v2 publisher. Returns ``(publisher, local_shard)`` + so the caller can keep the NIXL agent and its registered buffer alive while inference pulls. + """ + from modelexpress import MxV2TrainingPublisher + + log.info(f"[trainer R{rank}] publishing v={version}") + + pub = MxV2TrainingPublisher( + agent_name=f"v2dt-trainer-r{rank}-v{version}", + device_id=rank, + mx_server_url=mx_url, + worker_rank=rank, + world_layout=layout, + heartbeat=False, + ) + pub.initialize(model_name=MODEL_NAME, dtype="bfloat16") + + sentinel = _fingerprint(rank, version) + + # Build a placeholder global tensor and distribute via FSDP-style Shard(0). + # We deliberately seed the local view AFTER distribute_tensor so each rank + # owns a recognizably-different sentinel (distribute_tensor only respects + # rank-0's source, so seeding before would give every rank the same data). 
+ with torch.cuda.device(rank): + global_placeholder = torch.zeros( + (HIDDEN, INTER), dtype=torch.bfloat16, device=f"cuda:{rank}" + ) + sharded_dt: DTensor = distribute_tensor(global_placeholder, mesh, [Shard(0)]) + + local = sharded_dt.to_local() + assert local.shape[0] == HIDDEN // WORLD_SIZE, (local.shape, HIDDEN, WORLD_SIZE) + assert sharded_dt.shape[0] == HIDDEN, (sharded_dt.shape, HIDDEN) + + # Seed this rank's local shard with its own sentinel. + local.fill_(float(sentinel)) + # bf16 round-trip self-check: multiples of 8 are exact. + assert abs(local[0, 0].item() - sentinel) < 0.5, (local[0, 0].item(), sentinel) + + log.info( + f"[trainer R{rank}] DTensor: global={tuple(sharded_dt.shape)} " + f"local={tuple(local.shape)} sentinel={sentinel}" + ) + + # Add as a DTensor — describe_tensor reads .placements off the DTensor + # and now correctly handles `tensor.shape == global` semantics. + # global_shape, shard_axis, and local_shard_range are all computed + # from the DTensor view; no manual override needed. + pub.add_tensor(name="model.layers.0.qkv_proj.weight", tensor=sharded_dt) + + # NIXL-register the LOCAL shard, not the DTensor (which has no + # data_ptr()). This mirrors what the NemoRL DTensorPolicyWorker does + # after the tensor.to_local() call. 
+ pub._registered_tensors["model.layers.0.qkv_proj.weight"] = local.contiguous() + + mx_source_id = pub.publish(version=version) + pub.mark_ready() + log.info( + f"[trainer R{rank}] published v={version} mx_source_id={mx_source_id}" + ) + return pub, local + + +def inference_receive_dt(rank: int, version: int, mx_url: str) -> bool: + """Pull our same-rank trainer's local shard, verify byte correctness AND + that the registry exposes the GLOBAL shape (un-sharded).""" + from modelexpress import MxV2RefitReceiver + + log.info(f"[inference R{rank}] starts; v>={version}") + rec = MxV2RefitReceiver( + agent_name=f"v2dt-inference-r{rank}-v{version}", + device_id=rank, + mx_server_url=mx_url, + worker_rank=rank, + ) + with torch.cuda.device(rank): + recv_local = torch.zeros(HIDDEN // WORLD_SIZE, INTER, + dtype=torch.bfloat16, device=f"cuda:{rank}") + rec.initialize(model_tensors={"model.layers.0.qkv_proj.weight": recv_local}) + + deadline = time.perf_counter() + 30.0 + candidates = [] + while time.perf_counter() < deadline: + candidates = rec.discover_v2_sources( + model_name=MODEL_NAME, + min_version=version, + same_rank_only=True, + include_replicas=False, + ) + if candidates: + break + time.sleep(0.5) + + if not candidates: + log.error(f"[inference R{rank}] no v2 source found") + return False + + chosen = rec.pick_best_source(candidates) + log.info( + f"[inference R{rank}] picked role={chosen.role} src_rank={chosen.worker_rank} " + f"v={chosen.ref.training_step}" + ) + + # KEY ASSERTION: the registry the trainer published should expose the + # GLOBAL shape, not the local shape. (The local shape is what NIXL + # actually transferred; the global shape is the un-sharded view.) 
+ if chosen.registry is not None: + for td in chosen.registry["tensors"]: + if td.name == "model.layers.0.qkv_proj.weight": + log.info( + f"[inference R{rank}] registry: global={td.global_shape} " + f"placement={td.placement_kind} shard_axis={td.shard_axis} " + f"local_range={td.local_shard_range}" + ) + assert td.global_shape == (HIDDEN, INTER), (td.global_shape, HIDDEN, INTER) + assert td.placement_kind == "SHARD" + assert td.shard_axis == 0 + expected_lo = rank * (HIDDEN // WORLD_SIZE) + expected_hi = expected_lo + (HIDDEN // WORLD_SIZE) + assert td.local_shard_range == (expected_lo, expected_hi), ( + td.local_shard_range, expected_lo, expected_hi + ) + break + else: + log.warning(f"[inference R{rank}] tensor not in registry (sidecar may be missing)") + else: + log.warning(f"[inference R{rank}] registry missing on candidate (sidecar transport drop?)") + + bytes_received = 0 + t0 = time.perf_counter() + for name, tensor in rec.receive_from(chosen, timeout_seconds=60.0): + bytes_received += tensor.numel() * tensor.element_size() + log.info(f"[inference R{rank}] received '{name}' shape={tuple(tensor.shape)}") + elapsed = time.perf_counter() - t0 + bw_mbps = bytes_received / 1e6 / elapsed if elapsed > 0 else 0.0 + + expected_value = _fingerprint(rank, version) + actual_value = recv_local[0, 0].item() + # Stronger check: every cell of the received local shard should equal sentinel + # (the trainer filled the whole local view). 
+ elem_match = bool(torch.allclose( + recv_local.float(), torch.full_like(recv_local, float(expected_value)).float(), atol=0.5 + )) + log.info( + f"[inference R{rank}] {bytes_received/1e6:.2f} MB in {elapsed*1000:.0f} ms " + f"({bw_mbps:.0f} MB/s); local[0,0]={actual_value:.0f} expected={expected_value} " + f"all_elem_match={elem_match}" + ) + + ok = abs(actual_value - expected_value) < 0.5 and elem_match + log.info(f"[inference R{rank}] correctness: {'OK' if ok else 'FAIL'}") + return ok + + +def per_rank_main(rank: int, return_dict): + from modelexpress import TrainerWorldLayout + + _setup_dist(rank, WORLD_SIZE) + mesh = init_device_mesh("cuda", (WORLD_SIZE,)) + layout = TrainerWorldLayout(fsdp_world_size=WORLD_SIZE) + + publishers = [] + all_ok = True + + for cycle in range(N_REFIT_CYCLES): + version = cycle + log.info(f"=== R{rank} cycle {cycle} (version={version}) ===") + pub, _local = trainer_publish_dt(rank, version, layout, MX_URL, mesh) + publishers.append(pub) + + dist.barrier() # ensure all trainers published before inference polls + time.sleep(1.0) + + ok = inference_receive_dt(rank, version, MX_URL) + all_ok = all_ok and ok + + dist.barrier() + time.sleep(1.0) + + for p in publishers: + try: + p.shutdown() + except Exception as e: + log.warning(f"R{rank} shutdown: {e}") + + return_dict[rank] = all_ok + log.info(f"=== R{rank} done; all_ok={all_ok} ===") + + dist.destroy_process_group() + + +def main(): + log.info(f"=== v2 DTensor E2E: WORLD_SIZE={WORLD_SIZE} HIDDEN={HIDDEN} INTER={INTER} ===") + log.info(f"MX_URL={MX_URL} MODEL_NAME={MODEL_NAME} N_REFIT_CYCLES={N_REFIT_CYCLES}") + + if torch.cuda.device_count() < WORLD_SIZE: + log.error(f"need {WORLD_SIZE} GPUs, got {torch.cuda.device_count()}") + sys.exit(2) + + mp.set_start_method("spawn", force=True) + manager = mp.Manager() + return_dict = manager.dict() + + procs = [] + for rank in range(WORLD_SIZE): + p = mp.Process(target=per_rank_main, args=(rank, return_dict)) + p.start() + procs.append(p) + 
for p in procs: + p.join() + + log.info(f"=== summary: {dict(return_dict)} ===") + if all(return_dict.values()) and len(return_dict) == WORLD_SIZE: + log.info("=== ALL RANKS OK ===") + sys.exit(0) + else: + log.error("=== SOME RANKS FAILED ===") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/modelexpress_client/python/scripts/v2_moe_e2e_demo.py b/modelexpress_client/python/scripts/v2_moe_e2e_demo.py new file mode 100644 index 00000000..48b01532 --- /dev/null +++ b/modelexpress_client/python/scripts/v2_moe_e2e_demo.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python3 +"""End-to-end NIXL+MX v2 demo for MoE-style weight refit. + +Spawns WORLD_SIZE processes (default 4) via torch.multiprocessing, one per CUDA device. +Each process plays both 'trainer-rank-N' (publishes a fake MoE weight +shard for that rank's experts) and 'inference-rank-N' (discovers + pulls +its same-rank trainer's shard via NIXL RDMA). + +What this exercises end-to-end: + - MxV2TrainingPublisher: NIXL register, publish_metadata with v2 markers + (mx_v2=1, role=trainer, worker_rank=N, training_step=K), shape_registry + JSON, expert ownership IDs, agent_name fallback for old-server compat. + - MxV2RefitReceiver: discover_v2_sources with same_rank_only filter, + freshest-per-rank dedup, pick_best_source with MoE expert coverage, + receive_from (real RDMA READ: the receiver pulls from the trainer's VRAM). + - Heartbeat: NOT exercised here; publishers are created with + heartbeat=False because the demo is short-lived. + - Two refit cycles (N_REFIT_CYCLES, default 2) to demonstrate version progression. + +Run inside the trainer pod where the MX server is reachable as +'modelexpress-server.kavin.svc.cluster.local:8001' and NIXL is configured. + +Expected output (key lines): + [trainer R0] published v=0 mx_source_id=... + [inference R0] picked source role=trainer src_rank=0 v=0 + [inference R0] received 'model.layers.0.experts.weight' shape=... + [trainer R0] published v=1 mx_source_id=... 
+ [inference R0] picked source role=trainer src_rank=0 v=1 +""" +from __future__ import annotations + +import os +import sys +import time +import logging +import torch +import torch.multiprocessing as mp + +logging.basicConfig(level=logging.INFO, format="%(asctime)s [%(name)s] %(message)s") +log = logging.getLogger("v2-demo") + +MX_URL = os.environ.get("MX_URL", "modelexpress-server.kavin.svc.cluster.local:8001") +MODEL_NAME = os.environ.get("MODEL_NAME", "v2-demo/MoE-fake") +WORLD_SIZE = int(os.environ.get("WORLD_SIZE", "4")) +NUM_EXPERTS = int(os.environ.get("NUM_EXPERTS", "8")) # 8 experts per "layer" +HIDDEN = int(os.environ.get("HIDDEN", "256")) +N_REFIT_CYCLES = int(os.environ.get("N_REFIT_CYCLES", "2")) + +# FORCE_RDMA=1 disables UCX's intra-node `cuda_ipc` fast path so the demo +# exercises the same `rc_mlx5` (or `cuda_copy` over RDMA NIC) descriptor-list +# validation path that real cross-node transfers do. Without this, intra-node +# loopback runs through `cuda_ipc` which silently tolerates malformed +# descriptor entries, e.g. the v2 `__mx_v2_meta__` sidecar (addr=0, size=0) +# bug that MX PR #295 (commit 53c69ec) fixed. Set FORCE_RDMA=1 on every +# pre-deploy run so cross-host descriptor-list bugs surface in loopback. +if os.environ.get("FORCE_RDMA") == "1": + # rc_mlx5 = RDMA over RoCE/IB; cuda_copy = staged GPU↔NIC via host bounce. + # Both run UCX's strict prep_xfer_dlist validation. self+sm kept so the + # ZMQ/gRPC control plane still works. + os.environ["UCX_TLS"] = os.environ.get("UCX_TLS", "self,sm,rc_mlx5,cuda_copy,tcp") + log.info( + "FORCE_RDMA=1: UCX_TLS=%s (cuda_ipc disabled to exercise descriptor validation)", + os.environ["UCX_TLS"], + ) + + +def trainer_publish(rank: int, version: int, layout, mx_url: str): + """Run as the trainer side: publish a moe-flavored shard for our rank.""" + from modelexpress import MxV2TrainingPublisher + + # MoE expert layout: each rank owns NUM_EXPERTS / WORLD_SIZE experts. 
+ chunk = NUM_EXPERTS // WORLD_SIZE + owned = set(range(rank * chunk, (rank + 1) * chunk)) + log.info(f"[trainer R{rank}] starts; owns experts {sorted(owned)}; v={version}") + + pub = MxV2TrainingPublisher( + agent_name=f"v2demo-trainer-r{rank}-v{version}", + device_id=rank, + mx_server_url=mx_url, + worker_rank=rank, + world_layout=layout, + heartbeat=False, # short-lived demo; skip heartbeat + ) + pub.initialize(model_name=MODEL_NAME, dtype="bfloat16") + + # Fake MoE expert tensor: leading axis is the expert dim (= owned chunk). + # Each rank's local shard holds (chunk, HIDDEN, HIDDEN). + # Use exact-in-bfloat16 sentinel values: bf16 has an 8-bit significand, so + # multiples of 8 are exact for magnitudes up to 2^11. Encode sentinel as + # ((rank+1) * 8 + version * 64); (rank, version) pairs map to distinct + # values as long as WORLD_SIZE <= 8. + sentinel = (rank + 1) * 8 + version * 64 + with torch.cuda.device(rank): + moe_w = torch.randn(chunk, HIDDEN, HIDDEN, dtype=torch.bfloat16, device=f"cuda:{rank}") + moe_w[0, 0, 0] = float(sentinel) + # Plus a non-expert tensor (replicated across ranks). + ln_w = torch.ones(HIDDEN, dtype=torch.bfloat16, device=f"cuda:{rank}") + ln_w[0] = float(sentinel) + + pub.add_tensor( + name="model.layers.0.experts.weight", + tensor=moe_w, + is_expert=True, + expert_axis=0, + owned_expert_ids=owned, + ) + pub.add_tensor(name="model.layers.0.layer_norm.weight", tensor=ln_w) + + mx_source_id = pub.publish(version=version) + pub.mark_ready() + log.info( + f"[trainer R{rank}] published v={version} mx_source_id={mx_source_id} " + f"sentinel_target={sentinel} got={moe_w[0, 0, 0].item():.0f}" + ) + # Hand back the publisher so the worker can call shutdown after its peer reads. 
+ return pub, moe_w, ln_w + + +def inference_receive(rank: int, version: int, mx_url: str): + """Run as the inference side: discover + pull our same-rank trainer.""" + from modelexpress import MxV2RefitReceiver + + chunk = NUM_EXPERTS // WORLD_SIZE + log.info(f"[inference R{rank}] starts; expects experts of size {chunk}; v={version}") + + rec = MxV2RefitReceiver( + agent_name=f"v2demo-inference-r{rank}-v{version}", + device_id=rank, + mx_server_url=mx_url, + worker_rank=rank, + ) + + # Pre-allocate receive buffers matching the trainer's shape. + with torch.cuda.device(rank): + recv_moe = torch.zeros( + chunk, HIDDEN, HIDDEN, dtype=torch.bfloat16, device=f"cuda:{rank}" + ) + recv_ln = torch.zeros(HIDDEN, dtype=torch.bfloat16, device=f"cuda:{rank}") + rec.initialize( + model_tensors={ + "model.layers.0.experts.weight": recv_moe, + "model.layers.0.layer_norm.weight": recv_ln, + } + ) + + # Discover same-rank source, with v2-only filter. Poll for up to 30 s + # to handle propagation delays. 
+ deadline = time.perf_counter() + 30.0 + candidates = [] + while time.perf_counter() < deadline: + candidates = rec.discover_v2_sources( + model_name=MODEL_NAME, + min_version=version, + same_rank_only=True, + include_replicas=True, + ) + if candidates: + break + time.sleep(0.5) + + if not candidates: + log.error(f"[inference R{rank}] no v2 source found (timeout)") + return False + + chosen = rec.pick_best_source(candidates) + log.info( + f"[inference R{rank}] picked source role={chosen.role} " + f"src_rank={chosen.worker_rank} v={chosen.ref.training_step} " + f"updated_at={chosen.updated_at}" + ) + + bytes_received = 0 + t0 = time.perf_counter() + for name, tensor in rec.receive_from(chosen, timeout_seconds=60.0): + bytes_received += tensor.numel() * tensor.element_size() + log.info( + f"[inference R{rank}] received '{name}' shape={tuple(tensor.shape)} " + f"dtype={tensor.dtype}" + ) + elapsed = time.perf_counter() - t0 + bw_mbps = bytes_received / 1e6 / elapsed if elapsed > 0 else 0.0 + + # Verify fingerprints match. 
+ expected_sentinel = (rank + 1) * 8 + version * 64 + actual_moe = recv_moe[0, 0, 0].item() + actual_ln = recv_ln[0].item() + log.info( + f"[inference R{rank}] {bytes_received/1e6:.2f} MB in {elapsed*1000:.0f} ms " + f"({bw_mbps:.0f} MB/s); moe[0,0,0]={actual_moe:.0f} ln[0]={actual_ln:.0f} " + f"expected={expected_sentinel}" + ) + ok = ( + abs(actual_moe - expected_sentinel) < 0.5 + and abs(actual_ln - expected_sentinel) < 0.5 + ) + log.info(f"[inference R{rank}] correctness: {'OK' if ok else 'FAIL'}") + + # Tree fan-out: republish self as inference_replica + rec.publish_self_as_source(version=version, model_name=MODEL_NAME) + + return ok + + +def per_rank_main(rank: int, return_dict): + """Entry point for each spawned process — plays both trainer & inference.""" + from modelexpress import TrainerWorldLayout + + layout = TrainerWorldLayout(fsdp_world_size=WORLD_SIZE, ep_world_size=WORLD_SIZE) + publishers = [] + all_ok = True + + for cycle in range(N_REFIT_CYCLES): + version = cycle # 0, 1, ... + log.info(f"=== R{rank} cycle {cycle} (version={version}) ===") + + pub, _moe, _ln = trainer_publish(rank, version, layout, MX_URL) + publishers.append(pub) + + # Tiny barrier via wallclock — give all trainers ~2s to publish so + # discover_v2_sources sees a coherent set. 
+ time.sleep(2.0) + + ok = inference_receive(rank, version, MX_URL) + all_ok = all_ok and ok + + # Inter-cycle gap so version-N is observably newer than version-(N-1) + time.sleep(2.0) + + # Drop publishers (releases NIXL agents) + for p in publishers: + try: + p.shutdown() + except Exception as e: + log.warning(f"R{rank} shutdown: {e}") + + return_dict[rank] = all_ok + log.info(f"=== R{rank} done; all_ok={all_ok} ===") + + +def main(): + log.info(f"=== v2 MoE E2E demo: WORLD_SIZE={WORLD_SIZE} NUM_EXPERTS={NUM_EXPERTS} ===") + log.info(f"MX_URL={MX_URL} MODEL_NAME={MODEL_NAME} N_REFIT_CYCLES={N_REFIT_CYCLES}") + + if torch.cuda.device_count() < WORLD_SIZE: + log.error( + f"need {WORLD_SIZE} GPUs, only have {torch.cuda.device_count()}; aborting" + ) + sys.exit(2) + + mp.set_start_method("spawn", force=True) + manager = mp.Manager() + return_dict = manager.dict() + + procs = [] + for rank in range(WORLD_SIZE): + p = mp.Process(target=per_rank_main, args=(rank, return_dict)) + p.start() + procs.append(p) + for p in procs: + p.join() + + log.info(f"=== summary: {dict(return_dict)} ===") + if all(return_dict.values()) and len(return_dict) == WORLD_SIZE: + log.info("=== ALL RANKS OK ===") + sys.exit(0) + else: + log.error("=== SOME RANKS FAILED ===") + sys.exit(1) + + +if __name__ == "__main__": + main() diff --git a/modelexpress_client/python/tests/test_bench_elastic_scaling.py b/modelexpress_client/python/tests/test_bench_elastic_scaling.py new file mode 100644 index 00000000..04a2c9a6 --- /dev/null +++ b/modelexpress_client/python/tests/test_bench_elastic_scaling.py @@ -0,0 +1,235 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the elastic-scaling benchmark harness. + +We exercise the orchestrator's CPU smoke path + the result aggregation +logic without touching MX, NIXL, or torch.cuda. 
The "live" subprocess +path is tested by running it against a real MX server (out of scope +for the unit test suite — see ``bench_elastic_scaling.py --mode=live`` +for the integration smoke). +""" + +from __future__ import annotations + +import importlib.util +import json +import os +import sys +from argparse import Namespace +from pathlib import Path + +import pytest + + +_HERE = Path(__file__).resolve().parent +_BENCH = _HERE.parent / "benchmarks" / "bench_elastic_scaling.py" + + +@pytest.fixture(scope="module") +def bench(): + """Load the benchmark script as a module.""" + spec = importlib.util.spec_from_file_location("bench_es", _BENCH) + mod = importlib.util.module_from_spec(spec) + sys.modules["bench_es"] = mod + spec.loader.exec_module(mod) + return mod + + +def _args(**overrides): + """Build a Namespace with defaults that match the CLI.""" + defaults = dict( + scenario="elastic_scale", + mode="cpu", + mx_server_url="localhost:8001", + model_name="bench/m", + num_receivers=3, + num_tensors=4, + tensor_bytes=1024 * 1024, + steps=2, + step_interval=0.0, + join_interval=0.5, + trainer_warmup=0.0, + poll_interval=0.1, + cycle_timeout=5.0, + deadline=10.0, + trainer_compile_target="cutlass_fp8", + tmpdir="/tmp/mx_bench_test", + output=None, + verbose=False, + ) + defaults.update(overrides) + return Namespace(**defaults) + + +def test_cpu_smoke_produces_consistent_result(bench): + """CPU smoke mode runs end-to-end and produces a well-formed BenchResult.""" + result = bench.run_cpu_smoke(_args(num_receivers=2, steps=3)) + assert result.scenario == "elastic_scale_cpu_smoke" + assert result.trainer is not None + assert result.trainer.published_versions == [1, 2, 3] + assert len(result.receivers) == 2 + for r in result.receivers: + assert r.join_latency_seconds == 0.05 + assert len(r.cycles) == 3 + for c in r.cycles: + assert c.bytes_received == 4 * 1024 * 1024 + assert c.source_worker_rank == 0 + + +def test_receiver_bandwidth_math(bench): + """avg_bandwidth_gbps 
math: total_bits / total_rdma_seconds / 1e9.""" + r = bench.ReceiverResult( + receiver_id="r", worker_rank=0, started_at=0.0 + ) + r.cycles.append(bench.ReceiverCycleResult( + version=1, bytes_received=1_250_000_000, rdma_seconds=1.0 + )) + # 1.25 GB * 8 = 10 Gbits, over 1s → 10 Gbps + assert r.avg_bandwidth_gbps() == pytest.approx(10.0) + + +def test_trainer_egress_under_no_tree_equals_total(bench): + """Without tree fan-out, all receivers pull from the trainer (rank 0), + so trainer_egress_bytes should equal total_delivered.""" + result = bench.run_cpu_smoke(_args(num_receivers=4, steps=2)) + total = sum(r.total_bytes() for r in result.receivers) + assert result.trainer_egress_bytes() == total + + +def test_trainer_egress_under_tree_fanout_can_be_less(bench): + """If receivers report source_worker_rank != 0, those bytes are NOT + counted against the trainer's egress. This is how the harness + measures the tree-fan-out savings.""" + # Hand-build a result where 2 of the 3 receivers pulled from a replica + result = bench.BenchResult( + scenario="tree_fanout", + config={}, + trainer=bench.TrainerResult( + worker_id="t", mx_source_id="sid", started_at=0.0, + published_versions=[1], compile_target="cutlass_fp8", + total_published_bytes=1_000_000, + ), + receivers=[ + bench.ReceiverResult( + receiver_id=f"r{i}", worker_rank=0, started_at=0.0, + cycles=[bench.ReceiverCycleResult( + version=1, bytes_received=1_000_000, + source_worker_rank=(0 if i == 0 else 1), + )] + ) + for i in range(3) + ], + started_at=0.0, finished_at=1.0, + ) + assert result.trainer_egress_bytes() == 1_000_000 # only r0 went to trainer + total = sum(r.total_bytes() for r in result.receivers) + assert total == 3_000_000 + + +def test_compile_target_verdicts(bench): + """Compile-target scenario derivation: receivers with any successful + cycle are 'accepted'; receivers with only error cycles are 'rejected'.""" + result = bench.BenchResult( + scenario="compile_target", config={}, + 
trainer=bench.TrainerResult( + worker_id="t", mx_source_id="sid", started_at=0.0, + published_versions=[1], compile_target="cutlass_fp8", + total_published_bytes=1_000_000, + ), + receivers=[ + bench.ReceiverResult( + receiver_id="recv-match", worker_rank=0, started_at=0.0, + compile_target_filter=["cutlass_fp8"], + cycles=[bench.ReceiverCycleResult(version=1, bytes_received=1_000_000)], + ), + bench.ReceiverResult( + receiver_id="recv-mismatch", worker_rank=0, started_at=0.0, + compile_target_filter=["deep_gemm_fp8"], + cycles=[bench.ReceiverCycleResult( + version=1, error="no source matches filters" + )], + ), + bench.ReceiverResult( + receiver_id="recv-no-filter", worker_rank=0, started_at=0.0, + compile_target_filter=None, + cycles=[bench.ReceiverCycleResult(version=1, bytes_received=1_000_000)], + ), + ], + started_at=0.0, finished_at=1.0, + ) + derived = bench._scenario_derived(result) + assert derived["verdicts"] == { + "recv-match": "accepted", + "recv-mismatch": "rejected", + "recv-no-filter": "accepted", + } + + +def test_p99_p50_join_latency(bench): + """Elastic-scale percentile helpers.""" + assert bench._p([], 0.5) is None + assert bench._p([1.0], 0.5) == 1.0 + assert bench._p([1, 2, 3, 4, 5], 0.5) == 3 + assert bench._p([1, 2, 3, 4, 5], 0.99) == 5 + + +def test_summary_table_includes_fanout_factor_for_tree_scenario(bench): + """The human-readable summary should call out the fan-out factor on + the tree_fanout scenario only.""" + result = bench.BenchResult( + scenario="tree_fanout", config={}, + trainer=bench.TrainerResult( + worker_id="t", mx_source_id="sid", started_at=0.0, + published_versions=[1], compile_target="cutlass_fp8", + total_published_bytes=1_000_000, + ), + receivers=[ + bench.ReceiverResult( + receiver_id="r0", worker_rank=0, started_at=0.0, + cycles=[bench.ReceiverCycleResult( + version=1, bytes_received=1_000_000, source_worker_rank=0 + )] + ), + bench.ReceiverResult( + receiver_id="r1", worker_rank=0, started_at=0.0, + 
cycles=[bench.ReceiverCycleResult( + version=1, bytes_received=1_000_000, source_worker_rank=1 + )] + ), + ], + started_at=0.0, finished_at=1.0, + ) + table = result.to_summary_table() + assert "fanout_factor" in table + assert "Tree fan-out" in table + + # The elastic_scale variant should NOT mention fanout_factor + result.scenario = "elastic_scale" + table = result.to_summary_table() + assert "fanout_factor" not in table + + +def test_argparse_defaults_round_trip(bench): + """The CLI parser produces a Namespace the orchestrator can consume.""" + ns = bench._parse_args(["--scenario=compile_target", "--num-receivers=5"]) + assert ns.scenario == "compile_target" + assert ns.num_receivers == 5 + + +def test_main_cpu_mode_writes_output(bench, tmp_path): + """End-to-end CLI: --mode=cpu --output writes a JSON file with the + expected shape.""" + out = tmp_path / "result.json" + rc = bench.main([ + "--mode=cpu", + "--scenario=tree_fanout", + "--num-receivers=2", + "--steps=2", + "--output", str(out), + ]) + assert rc == 0 + data = json.loads(out.read_text()) + assert data["scenario"] == "tree_fanout_cpu_smoke" + assert data["derived"]["scenario_specific"]["fanout_factor"] is not None + assert len(data["receivers"]) == 2 diff --git a/modelexpress_client/python/tests/test_v2_shape_registry.py b/modelexpress_client/python/tests/test_v2_shape_registry.py new file mode 100644 index 00000000..2b1297b4 --- /dev/null +++ b/modelexpress_client/python/tests/test_v2_shape_registry.py @@ -0,0 +1,283 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Tests for v2 shape descriptors (no GPU / no NIXL required).""" + +from __future__ import annotations + +import sys +import importlib.util +from pathlib import Path + +import pytest + + +# Direct-load shape_descriptors so we can run without the full modelexpress +# package being importable (the package init pulls in nixl_transfer which +# requires a CUDA build to import). +_HERE = Path(__file__).resolve().parent +_MOD_PATH = _HERE.parent / "modelexpress" / "shape_descriptors.py" + + +@pytest.fixture(scope="module") +def sd(): + spec = importlib.util.spec_from_file_location( + "modelexpress.shape_descriptors_for_test", _MOD_PATH + ) + mod = importlib.util.module_from_spec(spec) + sys.modules["modelexpress.shape_descriptors_for_test"] = mod + spec.loader.exec_module(mod) + return mod + + +def test_replicate_descriptor_round_trip(sd): + import torch + + t = torch.randn(8, 16, dtype=torch.bfloat16) + desc = sd.describe_tensor(name="lm_head.weight", tensor=t, rank=0, fsdp_world_size=1) + assert desc.placement_kind == sd.PLACEMENT_REPLICATE + assert desc.global_shape == (8, 16) + assert desc.local_shard_range is None + blob = sd.encode_registry([desc], version=1, trainer_world_layout="fsdp:1") + parsed = sd.decode_registry(blob) + assert parsed["tensors"][0].name == "lm_head.weight" + assert parsed["tensors"][0].placement_kind == sd.PLACEMENT_REPLICATE + + +def test_sharded_dtensor_local_range(sd): + import torch + from torch.distributed.tensor.placement_types import Shard + + class FakeDT: + def __init__(self, shape, dtype, placements): + self.shape = torch.Size(shape) + self.dtype = dtype + self.placements = placements + + # Simulating an FSDP shard: rank 2 of 4 holds rows [4, 6) + fake = FakeDT([2, 16], torch.bfloat16, [Shard(0)]) + desc = sd.describe_tensor( + name="model.layers.0.mlp.gate_proj.weight", + tensor=fake, + rank=2, + fsdp_world_size=4, + ) + assert desc.placement_kind == sd.PLACEMENT_SHARD + assert 
desc.shard_axis == 0 + assert desc.global_shape == (8, 16) + assert desc.local_shard_range == (4, 6) + + +def test_moe_expert_descriptor_in_registry(sd): + import torch + from torch.distributed.tensor.placement_types import Shard + + class FakeDT: + def __init__(self, shape, dtype, placements): + self.shape = torch.Size(shape) + self.dtype = dtype + self.placements = placements + + fake = FakeDT([24, 4096, 12288], torch.bfloat16, [Shard(0)]) + desc = sd.describe_tensor( + name="model.layers.5.mlp.experts.weight", + tensor=fake, + rank=2, + fsdp_world_size=8, + is_expert=True, + expert_axis=0, + owned_expert_ids={48, 49, 50, 51, 52, 53}, + ) + assert desc.is_expert + assert desc.global_shape == (192, 4096, 12288) + assert desc.local_shard_range == (48, 72) + assert set(desc.owned_expert_ids) == {48, 49, 50, 51, 52, 53} + + blob = sd.encode_registry([desc], version=99, trainer_world_layout="fsdp:8,ep:8") + parsed = sd.decode_registry(blob) + parsed_desc = parsed["tensors"][0] + assert parsed_desc.is_expert + assert set(parsed_desc.owned_expert_ids) == {48, 49, 50, 51, 52, 53} + assert parsed_desc.global_shape == (192, 4096, 12288) + + +def test_expert_owner_map_uniform(sd): + m = sd.even_expert_owner_map(num_experts=192, ep_world_size=8) + assert all(len(s) == 24 for s in m.values()) + assert sum(len(s) for s in m.values()) == 192 + # Coverage is exactly the union of [0..192). 
+ flat = set().union(*m.values()) + assert flat == set(range(192)) + + +def test_expert_owner_map_rejects_uneven(sd): + with pytest.raises(ValueError, match="not divisible"): + sd.even_expert_owner_map(num_experts=190, ep_world_size=8) + + +def test_expert_set_codec_round_trip(sd): + es = {3, 1, 2, 5, 4, 5, 1} # duplicates collapse + encoded = sd.encode_expert_set(es) + assert encoded == "1,2,3,4,5" + assert sd.decode_expert_set(encoded) == {1, 2, 3, 4, 5} + + +def test_decode_expert_set_handles_empty_and_whitespace(sd): + assert sd.decode_expert_set("") == set() + assert sd.decode_expert_set(None) == set() + assert sd.decode_expert_set(" 1, 2 , ,3 ") == {1, 2, 3} + + +def test_registry_full_round_trip_multitensor(sd): + """Trainer-side encode → wire → receiver-side decode preserves everything.""" + import torch + from torch.distributed.tensor.placement_types import Shard + + class FakeDT: + def __init__(self, shape, dtype, placements): + self.shape = torch.Size(shape) + self.dtype = dtype + self.placements = placements + + descriptors = [ + sd.describe_tensor( + name="lm_head.weight", + tensor=torch.randn(2048, 4096, dtype=torch.bfloat16), + rank=0, + fsdp_world_size=1, + ), + sd.describe_tensor( + name="model.layers.0.mlp.gate_up_proj.weight", + tensor=FakeDT([2048, 4096], torch.bfloat16, [Shard(0)]), + rank=3, + fsdp_world_size=4, + ), + sd.describe_tensor( + name="model.layers.0.mlp.experts.weight", + tensor=FakeDT([24, 4096, 12288], torch.bfloat16, [Shard(0)]), + rank=3, + fsdp_world_size=8, + is_expert=True, + owned_expert_ids={72, 73, 74, 75, 76, 77}, + ), + ] + blob = sd.encode_registry( + descriptors, version=1234, trainer_world_layout="fsdp:8,ep:8" + ) + parsed = sd.decode_registry(blob) + assert parsed["version"] == 1234 + assert parsed["trainer_world_layout"] == "fsdp:8,ep:8" + assert len(parsed["tensors"]) == 3 + + by_name = {t.name: t for t in parsed["tensors"]} + # 1) lm_head: replicate + h = by_name["lm_head.weight"] + assert h.placement_kind == 
sd.PLACEMENT_REPLICATE + assert h.global_shape == (2048, 4096) + # 2) gate_up_proj: sharded + g = by_name["model.layers.0.mlp.gate_up_proj.weight"] + assert g.placement_kind == sd.PLACEMENT_SHARD + assert g.global_shape == (8192, 4096) # 2048 * 4 + assert g.local_shard_range == (6144, 8192) # rank 3 of 4 + # 3) MoE expert + e = by_name["model.layers.0.mlp.experts.weight"] + assert e.is_expert + assert e.global_shape == (192, 4096, 12288) + assert e.local_shard_range == (72, 96) + assert set(e.owned_expert_ids) == {72, 73, 74, 75, 76, 77} + + +# Phase 3a / 3b: compile_target + compile_metadata round-trip. + + +def test_compile_target_default_is_hf_raw(sd): + import torch + + t = torch.randn(8, 16, dtype=torch.bfloat16) + desc = sd.describe_tensor(name="lm_head.weight", tensor=t, rank=0, fsdp_world_size=1) + assert desc.compile_target == sd.COMPILE_TARGET_HF_RAW + assert desc.compile_metadata == {} + + +def test_compile_target_round_trip(sd): + import torch + + t = torch.randn(8, 16, dtype=torch.bfloat16) + desc = sd.describe_tensor( + name="model.layers.0.mlp.gate_proj.weight", + tensor=t, + rank=0, + fsdp_world_size=1, + compile_target=sd.COMPILE_TARGET_DEEPGEMM_FP8, + compile_metadata={"block_size": 128, "scale_layout": "K-major"}, + ) + assert desc.compile_target == sd.COMPILE_TARGET_DEEPGEMM_FP8 + assert desc.compile_metadata == {"block_size": 128, "scale_layout": "K-major"} + + blob = sd.encode_registry([desc], version=7, trainer_world_layout="fsdp:1") + parsed = sd.decode_registry(blob) + out = parsed["tensors"][0] + assert out.compile_target == sd.COMPILE_TARGET_DEEPGEMM_FP8 + assert out.compile_metadata == {"block_size": 128, "scale_layout": "K-major"} + + +def test_compile_target_omitted_from_wire_when_default(sd): + import json + import torch + + t = torch.randn(8, 16, dtype=torch.bfloat16) + desc = sd.describe_tensor(name="x", tensor=t, rank=0, fsdp_world_size=1) + blob = sd.encode_registry([desc], version=1, trainer_world_layout="fsdp:1") + obj = 
json.loads(blob) + assert "compile_target" not in obj["tensors"][0] + assert "compile_metadata" not in obj["tensors"][0] + + +def test_compile_target_matches_no_filter_is_accept(sd): + desc = sd.TensorDescriptorV2( + name="w", + global_shape=(4,), + dtype="bfloat16", + compile_target=sd.COMPILE_TARGET_CUTLASS_FP8, + ) + assert sd.compile_target_matches(desc, allowed_targets=None) + + +def test_compile_target_matches_whitelist(sd): + desc = sd.TensorDescriptorV2( + name="w", + global_shape=(4,), + dtype="bfloat16", + compile_target=sd.COMPILE_TARGET_DEEPGEMM_FP8, + ) + assert sd.compile_target_matches( + desc, allowed_targets={sd.COMPILE_TARGET_DEEPGEMM_FP8, sd.COMPILE_TARGET_HF_RAW} + ) + assert not sd.compile_target_matches( + desc, allowed_targets={sd.COMPILE_TARGET_CUTLASS_FP8} + ) + + +def test_compile_target_matches_required_metadata(sd): + desc = sd.TensorDescriptorV2( + name="w", + global_shape=(4,), + dtype="bfloat16", + compile_target=sd.COMPILE_TARGET_DEEPGEMM_FP8, + compile_metadata={"block_size": 128, "scale_layout": "K-major"}, + ) + assert sd.compile_target_matches( + desc, + allowed_targets={sd.COMPILE_TARGET_DEEPGEMM_FP8}, + required_metadata={"block_size": 128}, + ) + assert not sd.compile_target_matches( + desc, + allowed_targets={sd.COMPILE_TARGET_DEEPGEMM_FP8}, + required_metadata={"block_size": 256}, + ) + assert not sd.compile_target_matches( + desc, + allowed_targets={sd.COMPILE_TARGET_DEEPGEMM_FP8}, + required_metadata={"k_split": 4}, + ) diff --git a/modelexpress_client/python/tests/test_v2_source_picker.py b/modelexpress_client/python/tests/test_v2_source_picker.py new file mode 100644 index 00000000..7f8cd442 --- /dev/null +++ b/modelexpress_client/python/tests/test_v2_source_picker.py @@ -0,0 +1,1081 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +"""Tests for v2 same-rank source filtering / freshest-per-rank dedup / tree +fan-out picker logic. + +Mocks out the underlying NIXL / gRPC layer so we can drive the V2 receiver's +`discover_v2_sources` + `pick_best_source` purely from Python. +""" + +from __future__ import annotations + +import importlib.util +import sys +import types +from dataclasses import dataclass +from pathlib import Path +from typing import Any +from unittest.mock import MagicMock + +import pytest + + +_HERE = Path(__file__).resolve().parent +_PKG_ROOT = _HERE.parent / "modelexpress" + + +def _load(modname: str, path: Path): + spec = importlib.util.spec_from_file_location(modname, path) + mod = importlib.util.module_from_spec(spec) + sys.modules[modname] = mod + spec.loader.exec_module(mod) + return mod + + +@pytest.fixture(scope="module") +def v2(): + """Load shape_descriptors and nemo_rl_v2 in isolation. + + nemo_rl_v2 normally imports MxRefitReceiver / MxTrainingPublisher / + HeartbeatThread at module top. We mock those out before exec'ing the + v2 module so its imports succeed without NIXL / gRPC. + """ + # Pre-create stub modules for the dependencies that nemo_rl_v2 imports. + pkg = types.ModuleType("modelexpress") + pkg.__path__ = [str(_PKG_ROOT)] # type: ignore[attr-defined] + sys.modules["modelexpress"] = pkg + + # p2p_pb2 stub: just the constants & message classes used. 
+ p2p_pb2 = types.ModuleType("modelexpress.p2p_pb2") + p2p_pb2.SOURCE_STATUS_READY = 2 + p2p_pb2.SOURCE_STATUS_INITIALIZING = 1 + p2p_pb2.SOURCE_STATUS_STALE = 3 + p2p_pb2.MX_SOURCE_TYPE_WEIGHTS = 0 + p2p_pb2.BACKEND_FRAMEWORK_UNKNOWN = 0 + sys.modules["modelexpress.p2p_pb2"] = p2p_pb2 + + @dataclass + class _SourceIdentity: + model_name: str = "" + mx_source_type: int = 0 + backend_framework: int = 0 + tensor_parallel_size: int = 0 + pipeline_parallel_size: int = 0 + expert_parallel_size: int = 0 + dtype: str = "" + quantization: str = "" + + def __post_init__(self): + self.extra_parameters = {} + + @dataclass + class _WorkerMetadata: + worker_rank: int = 0 + nixl_metadata: bytes = b"" + tensors: list = None + status: int = 0 + agent_name: str = "" + + def __post_init__(self): + if self.tensors is None: + self.tensors = [] + + @dataclass + class _TensorDescriptor: + name: str = "" + addr: int = 0 + size: int = 0 + device_id: int = 0 + dtype: str = "" + + p2p_pb2.SourceIdentity = _SourceIdentity # type: ignore[attr-defined] + p2p_pb2.WorkerMetadata = _WorkerMetadata # type: ignore[attr-defined] + p2p_pb2.TensorDescriptor = _TensorDescriptor # type: ignore[attr-defined] + + # Heartbeat stub: no-op start/stop. + hb = types.ModuleType("modelexpress.heartbeat") + + class _HBStub: + def __init__(self, *a, **kw): + self.started = False + + def start(self): + self.started = True + + def stop(self): + self.started = False + + hb.HeartbeatThread = _HBStub + sys.modules["modelexpress.heartbeat"] = hb + + # MxRefitReceiver / MxTrainingPublisher stubs. 
+ refit_mod = types.ModuleType("modelexpress.refit_receiver") + + @dataclass + class _SourceRef: + mx_source_id: str = "" + worker_id: str = "" + model_name: str = "" + worker_rank: int = 0 + training_step: int = 0 + + class _RefitStub: + def __init__(self, *a, **kw): + self._client = MagicMock() + self._nixl = MagicMock() + self._agent_name = kw.get("agent_name", "stub") + self._worker_id = "stub-worker" + # Tests inject scratch payloads via this dict: worker_id → {name: tensor}. + self._scratch_payloads: dict[str, dict[str, Any]] = {} + + def initialize(self, model_tensors=None): + pass + + def receive_weights(self, ref, timeout_seconds=300.0): + return iter([]) + + def receive_weights_scratch( + self, ref, timeout_seconds=300.0, tensor_shapes=None + ): + payload = self._scratch_payloads.get(ref.worker_id, {}) + for name, tensor in payload.items(): + yield name, tensor + + refit_mod.MxRefitReceiver = _RefitStub + refit_mod.SourceRef = _SourceRef + sys.modules["modelexpress.refit_receiver"] = refit_mod + + pub_mod = types.ModuleType("modelexpress.training_publisher") + + class _PubStub: + def __init__(self, *a, **kw): + self._client = None + self._nixl = None + self.mx_source_id = "abcd1234" + self.worker_id = "stub-pub-worker" + + def initialize(self, **kw): + pass + + def publish_weights(self, named_tensors, step, worker_rank): + return self.mx_source_id + + def mark_ready(self, worker_rank=0): + return True + + def shutdown(self): + pass + + def _build_identity(self, step): + ident = p2p_pb2.SourceIdentity() + return ident + + pub_mod.MxTrainingPublisher = _PubStub + sys.modules["modelexpress.training_publisher"] = pub_mod + + # types.TensorDescriptor stub + types_mod = types.ModuleType("modelexpress.types") + + @dataclass + class _TD: + name: str = "" + addr: int = 0 + size: int = 0 + device_id: int = 0 + dtype: str = "" + + types_mod.TensorDescriptor = _TD + sys.modules["modelexpress.types"] = types_mod + + # Now exec shape_descriptors + nemo_rl_v2 against 
this module space. + sd = _load("modelexpress.shape_descriptors", _PKG_ROOT / "shape_descriptors.py") + pkg.shape_descriptors = sd # type: ignore[attr-defined] + v2 = _load("modelexpress.nemo_rl_v2", _PKG_ROOT / "nemo_rl_v2.py") + return v2 + + +def _fake_instance(model_name, mx_source_id, worker_id): + return types.SimpleNamespace( + model_name=model_name, mx_source_id=mx_source_id, worker_id=worker_id + ) + + +def _fake_meta(role, worker_rank, training_step, updated_at, registry_blob=""): + """Build a fake get_metadata response with v2 metadata in extra_parameters.""" + + @dataclass + class _Meta: + found: bool = True + + def __post_init__(self): + from modelexpress.p2p_pb2 import SourceIdentity, WorkerMetadata + + self.identity = SourceIdentity(model_name="m") + self.identity.extra_parameters.update( + { + "mx_v2": "1", + "role": role, + "worker_rank": str(worker_rank), + "training_step": str(training_step), + "shape_registry": registry_blob, + } + ) + self.worker = WorkerMetadata() + # we tack updated_at on as an attribute since the proto stub doesn't + # natively expose it + self.worker.updated_at = updated_at + + return _Meta() + + +def test_same_rank_filter_dedup_freshest(v2): + """Multiple sources at the same rank → trainers sort first, freshest by updated_at first; replicas sort last.""" + receiver = v2.MxV2RefitReceiver( + agent_name="test-recv", + device_id=0, + mx_server_url="fake:8001", + worker_rank=2, + ) + receiver.initialize() + # Inject 4 fake sources at MX: + # rank 0 trainer (irrelevant — different rank) + # rank 2 trainer, version 5, updated_at 100 + # rank 2 trainer, version 5, updated_at 200 (FRESHER, should win) + # rank 2 inference_replica, version 5, updated_at 50 (kept, but sorted after the trainers) + response = MagicMock() + response.instances = [ + _fake_instance("m", "s0", "w_r0"), + _fake_instance("m", "s2", "w_r2_old"), + _fake_instance("m", "s2", "w_r2_new"), + _fake_instance("m", "s2", "w_r2_replica"), + ] + metas = { + "w_r0": _fake_meta("trainer", 0, 5, 
1000), + "w_r2_old": _fake_meta("trainer", 2, 5, 100), + "w_r2_new": _fake_meta("trainer", 2, 5, 200), + "w_r2_replica": _fake_meta("inference_replica", 2, 5, 50), + } + receiver._receiver._client.list_sources.return_value = response + receiver._receiver._client.get_metadata = lambda mx_source_id, worker_id: metas[ + worker_id + ] + + candidates = receiver.discover_v2_sources( + model_name="m", min_version=0, same_rank_only=True + ) + # rank 0 is filtered out (same_rank_only=True; receiver is rank 2) + assert all(c.worker_rank == 2 for c in candidates), candidates + # All 3 remaining (2 trainers + 1 replica) are returned, but trainer comes first + assert len(candidates) == 3 + assert candidates[0].role == "trainer" + # Among trainers, freshest first + assert candidates[0].ref.worker_id == "w_r2_new" + # Replica comes last + assert candidates[-1].role == "inference_replica" + + +def test_min_version_filter(v2): + """Sources whose version is below min_version are excluded.""" + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=0 + ) + receiver.initialize() + response = MagicMock() + response.instances = [ + _fake_instance("m", "s", "w_old"), + _fake_instance("m", "s", "w_cur"), + _fake_instance("m", "s", "w_new"), + ] + metas = { + "w_old": _fake_meta("trainer", 0, 1, 100), + "w_cur": _fake_meta("trainer", 0, 5, 200), + "w_new": _fake_meta("trainer", 0, 7, 300), + } + receiver._receiver._client.list_sources.return_value = response + receiver._receiver._client.get_metadata = lambda mx_source_id, worker_id: metas[ + worker_id + ] + + cands = receiver.discover_v2_sources(model_name="m", min_version=5) + versions = sorted(c.ref.training_step for c in cands) + assert versions == [5, 7] + + +def test_non_v2_sources_ignored(v2): + """Sources lacking ``mx_v2`` marker are ignored entirely.""" + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=0 + ) + receiver.initialize() + response = 
MagicMock() + response.instances = [ + _fake_instance("m", "s", "v2_worker"), + _fake_instance("m", "s", "v1_worker"), + ] + + @dataclass + class _MetaV1: + found: bool = True + + def __post_init__(self): + from modelexpress.p2p_pb2 import SourceIdentity, WorkerMetadata + + self.identity = SourceIdentity(model_name="m") + # No mx_v2 marker → v1 source + self.identity.extra_parameters.update( + {"role": "trainer", "training_step": "1"} + ) + self.worker = WorkerMetadata() + self.worker.updated_at = 999 + + metas = { + "v2_worker": _fake_meta("trainer", 0, 1, 100), + "v1_worker": _MetaV1(), + } + receiver._receiver._client.list_sources.return_value = response + receiver._receiver._client.get_metadata = lambda mx_source_id, worker_id: metas[ + worker_id + ] + + cands = receiver.discover_v2_sources(model_name="m") + assert len(cands) == 1 + assert cands[0].ref.worker_id == "v2_worker" + + +def test_pick_best_with_expert_filter(v2): + """When MoE expert filter is set, candidate must own all needed experts.""" + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=2 + ) + # Build candidates manually to test pick_best_source in isolation. 
+ candidates = [ + v2.V2SourceCandidate( + ref=type( + "Ref", + (), + { + "mx_source_id": "s", + "worker_id": "replica_partial", + "model_name": "m", + "worker_rank": 2, + "training_step": 5, + }, + )(), + role="inference_replica", + worker_rank=2, + registry=None, + owned_experts_per_layer={5: {48, 49, 50}}, # only 3 of needed 6 + updated_at=200, + ), + v2.V2SourceCandidate( + ref=type( + "Ref", + (), + { + "mx_source_id": "s", + "worker_id": "replica_full", + "model_name": "m", + "worker_rank": 2, + "training_step": 5, + }, + )(), + role="inference_replica", + worker_rank=2, + registry=None, + owned_experts_per_layer={5: {48, 49, 50, 51, 52, 53, 54, 55}}, + updated_at=100, # older but covers all needed + ), + ] + needed = {5: {48, 49, 50, 51, 52, 53}} + chosen = receiver.pick_best_source( + candidates, needed_experts_per_layer=needed + ) + assert chosen is not None + assert chosen.ref.worker_id == "replica_full" + + +def test_pick_best_falls_back_to_trainer(v2): + """Trainer always covers all experts (its registry is authoritative).""" + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=2 + ) + candidates = [ + v2.V2SourceCandidate( + ref=type( + "Ref", + (), + { + "mx_source_id": "s", + "worker_id": "trainer-2", + "model_name": "m", + "worker_rank": 2, + "training_step": 5, + }, + )(), + role="trainer", + worker_rank=2, + registry={"version": 5, "tensors": []}, + owned_experts_per_layer={}, # not populated for trainers + updated_at=300, + ), + ] + chosen = receiver.pick_best_source( + candidates, needed_experts_per_layer={5: {0, 1, 2, 3}} + ) + assert chosen is not None + assert chosen.ref.worker_id == "trainer-2" + + +def test_world_layout_round_trip(v2): + layout = v2.TrainerWorldLayout(fsdp_world_size=4, ep_world_size=8) + encoded = layout.encode() + assert encoded == "fsdp:4,tp:1,pp:1,ep:8" + rt = v2.TrainerWorldLayout.decode(encoded) + assert rt == layout + + +def 
test_agent_name_fallback_when_identity_missing(v2): + """Older servers don't return SourceIdentity in GetMetadataResponse; + the v2 receiver must fall back to parsing the v2 marker from + WorkerMetadata.agent_name.""" + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=2 + ) + receiver.initialize() + + response = MagicMock() + response.instances = [ + _fake_instance("m", "s", "v2_via_agent_name"), + ] + + class _MetaNoIdentity: + """Mimics an old-server GetMetadataResponse: no `identity` attribute.""" + + found = True + + def __init__(self): + from modelexpress.p2p_pb2 import WorkerMetadata + + self.worker = WorkerMetadata() + self.worker.updated_at = 12345 + # The publisher writes v2 markers into agent_name as a fallback + self.worker.agent_name = ( + "mx_v2|trainer|rank=2|version=42|orig=nemo-rl-trainer-r2" + ) + + receiver._receiver._client.list_sources.return_value = response + receiver._receiver._client.get_metadata = lambda mx_source_id, worker_id: ( + _MetaNoIdentity() + ) + + candidates = receiver.discover_v2_sources(model_name="m", min_version=0) + assert len(candidates) == 1 + cand = candidates[0] + assert cand.role == "trainer" + assert cand.worker_rank == 2 + assert cand.ref.training_step == 42 + assert cand.updated_at == 12345 + + +# ---------------------------------------------------------------------------- +# Phase 3b — compile_target_filter on discover_v2_sources +# ---------------------------------------------------------------------------- + + +def _registry_blob(v2, tensors): + """Helper: encode a registry from a list of TensorDescriptorV2 dicts.""" + sd = sys.modules["modelexpress.shape_descriptors"] + descriptors = [ + sd.TensorDescriptorV2( + name=t["name"], + global_shape=t.get("global_shape", (8, 16)), + dtype=t.get("dtype", "bfloat16"), + placement_kind=t.get("placement_kind", sd.PLACEMENT_REPLICATE), + shard_axis=t.get("shard_axis", 0), + local_shard_range=t.get("local_shard_range"), + 
compile_target=t.get("compile_target", sd.COMPILE_TARGET_HF_RAW), + compile_metadata=t.get("compile_metadata", {}), + ) + for t in tensors + ] + return sd.encode_registry(descriptors, version=1, trainer_world_layout="fsdp:1") + + +def _set_two_compile_sources(v2, receiver, hf_blob, fp8_blob): + response = MagicMock() + response.instances = [ + _fake_instance("m", "s", "trainer_hf"), + _fake_instance("m", "s", "trainer_fp8"), + ] + metas = { + "trainer_hf": _fake_meta("trainer", 0, 7, 200, registry_blob=hf_blob), + "trainer_fp8": _fake_meta("trainer", 0, 7, 300, registry_blob=fp8_blob), + } + receiver._receiver._client.list_sources.return_value = response + receiver._receiver._client.get_metadata = lambda mx_source_id, worker_id: metas[worker_id] + + +def test_compile_target_filter_accepts_only_matching(v2): + sd = sys.modules["modelexpress.shape_descriptors"] + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=0 + ) + receiver.initialize() + + hf_blob = _registry_blob(v2, [{"name": "w"}]) + fp8_blob = _registry_blob( + v2, + [{"name": "w", "compile_target": sd.COMPILE_TARGET_DEEPGEMM_FP8}], + ) + _set_two_compile_sources(v2, receiver, hf_blob, fp8_blob) + + only_hf = receiver.discover_v2_sources( + model_name="m", + min_version=0, + compile_target_filter={sd.COMPILE_TARGET_HF_RAW}, + ) + assert [c.ref.worker_id for c in only_hf] == ["trainer_hf"] + + only_fp8 = receiver.discover_v2_sources( + model_name="m", + min_version=0, + compile_target_filter={sd.COMPILE_TARGET_DEEPGEMM_FP8}, + ) + assert [c.ref.worker_id for c in only_fp8] == ["trainer_fp8"] + + both = receiver.discover_v2_sources( + model_name="m", + min_version=0, + compile_target_filter={ + sd.COMPILE_TARGET_HF_RAW, + sd.COMPILE_TARGET_DEEPGEMM_FP8, + }, + ) + assert {c.ref.worker_id for c in both} == {"trainer_hf", "trainer_fp8"} + + +def test_compile_target_filter_unset_admits_all(v2): + sd = sys.modules["modelexpress.shape_descriptors"] + receiver = 
v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=0 + ) + receiver.initialize() + + hf_blob = _registry_blob(v2, [{"name": "w"}]) + fp8_blob = _registry_blob( + v2, + [{"name": "w", "compile_target": sd.COMPILE_TARGET_DEEPGEMM_FP8}], + ) + _set_two_compile_sources(v2, receiver, hf_blob, fp8_blob) + + out = receiver.discover_v2_sources(model_name="m", min_version=0) + assert {c.ref.worker_id for c in out} == {"trainer_hf", "trainer_fp8"} + + +def test_compile_target_filter_rejects_when_no_registry(v2): + """If the candidate has no registry but caller wants compile filtering, + we MUST reject (we can't certify the bytes blindly).""" + sd = sys.modules["modelexpress.shape_descriptors"] + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=0 + ) + receiver.initialize() + + response = MagicMock() + response.instances = [_fake_instance("m", "s", "no_registry")] + metas = { + "no_registry": _fake_meta("trainer", 0, 1, 100, registry_blob=""), + } + receiver._receiver._client.list_sources.return_value = response + receiver._receiver._client.get_metadata = lambda mx_source_id, worker_id: metas[worker_id] + + # With no filter, candidate is admitted (back-compat). + assert len(receiver.discover_v2_sources(model_name="m")) == 1 + # With a filter, candidate is rejected (no registry → unknowable target). 
+ filtered = receiver.discover_v2_sources( + model_name="m", + compile_target_filter={sd.COMPILE_TARGET_HF_RAW}, + ) + assert filtered == [] + + +def test_compile_target_filter_required_metadata(v2): + sd = sys.modules["modelexpress.shape_descriptors"] + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=0 + ) + receiver.initialize() + + blob128 = _registry_blob( + v2, + [{ + "name": "w", + "compile_target": sd.COMPILE_TARGET_DEEPGEMM_FP8, + "compile_metadata": {"block_size": 128}, + }], + ) + blob256 = _registry_blob( + v2, + [{ + "name": "w", + "compile_target": sd.COMPILE_TARGET_DEEPGEMM_FP8, + "compile_metadata": {"block_size": 256}, + }], + ) + response = MagicMock() + response.instances = [ + _fake_instance("m", "s", "blk128"), + _fake_instance("m", "s", "blk256"), + ] + metas = { + "blk128": _fake_meta("trainer", 0, 1, 100, registry_blob=blob128), + "blk256": _fake_meta("trainer", 0, 1, 200, registry_blob=blob256), + } + receiver._receiver._client.list_sources.return_value = response + receiver._receiver._client.get_metadata = lambda mx_source_id, worker_id: metas[worker_id] + + keep_128 = receiver.discover_v2_sources( + model_name="m", + compile_target_filter={sd.COMPILE_TARGET_DEEPGEMM_FP8}, + required_compile_metadata={"block_size": 128}, + ) + assert [c.ref.worker_id for c in keep_128] == ["blk128"] + + +# ---------------------------------------------------------------------------- +# Phase 4 — multi-source slice discovery (discover_v2_sources_for_slice) +# ---------------------------------------------------------------------------- + + +def _build_two_trainers_tp4_mixed_tp8(v2): + """Two trainers at TP=4: rank 0 holds rows [0,2048), rank 1 holds [2048,4096) + on a tensor of global axis-0 extent 4096. 
Receiver at TP=8 rank N wants + rows [N*512, (N+1)*512).""" + sd = sys.modules["modelexpress.shape_descriptors"] + blob_r0 = _registry_blob( + v2, + [{ + "name": "w", + "global_shape": (4096, 1024), + "placement_kind": sd.PLACEMENT_SHARD, + "shard_axis": 0, + "local_shard_range": (0, 2048), + }], + ) + blob_r1 = _registry_blob( + v2, + [{ + "name": "w", + "global_shape": (4096, 1024), + "placement_kind": sd.PLACEMENT_SHARD, + "shard_axis": 0, + "local_shard_range": (2048, 4096), + }], + ) + return blob_r0, blob_r1 + + +def _set_two_trainers(v2, receiver, blob_r0, blob_r1): + response = MagicMock() + response.instances = [ + _fake_instance("m", "s", "trainer_r0"), + _fake_instance("m", "s", "trainer_r1"), + ] + metas = { + "trainer_r0": _fake_meta("trainer", 0, 7, 200, registry_blob=blob_r0), + "trainer_r1": _fake_meta("trainer", 1, 7, 200, registry_blob=blob_r1), + } + receiver._receiver._client.list_sources.return_value = response + receiver._receiver._client.get_metadata = lambda mx_source_id, worker_id: metas[worker_id] + + +def test_phase4_slice_within_one_trainer_shard(v2): + """Receiver TP=8 rank=1 wants rows [512, 1024) — fully inside trainer rank 0.""" + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=1 + ) + receiver.initialize() + blob_r0, blob_r1 = _build_two_trainers_tp4_mixed_tp8(v2) + _set_two_trainers(v2, receiver, blob_r0, blob_r1) + + plan = receiver.discover_v2_sources_for_slice( + model_name="m", + target_layout=v2.TargetTPLayout(world_size=8, rank=1, shard_axis=0), + same_rank_only=False, + ) + assert plan.fully_covered + contributions = plan.per_tensor_sources["w"] + assert len(contributions) == 1 + src = contributions[0] + assert src.candidate.ref.worker_id == "trainer_r0" + assert src.src_range == (512, 1024) + assert src.dst_range == (0, 512) + + +def test_phase4_slice_spans_two_trainer_shards(v2): + """Receiver TP=2 rank=0 wants [0, 2048) — exactly trainer rank 0 alone. 
+ Receiver TP=2 rank=1 wants [2048, 4096) — exactly trainer rank 1 alone. + Receiver TP=4 rank=1 wants [1024, 2048), which still sits inside trainer + rank 0. These aligned layouts never cross the rank 0/1 boundary at row + 2048, so this test passes an explicit target_range that straddles it.""" + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=0 + ) + receiver.initialize() + blob_r0, blob_r1 = _build_two_trainers_tp4_mixed_tp8(v2) + _set_two_trainers(v2, receiver, blob_r0, blob_r1) + + # target_range=(1500, 2500) straddles the trainer rank 0/1 boundary at 2048. + plan = receiver.discover_v2_sources_for_slice( + model_name="m", + target_layout=v2.TargetTPLayout( + world_size=1, rank=0, shard_axis=0, target_range=(1500, 2500) + ), + same_rank_only=False, + ) + assert plan.fully_covered + contributions = plan.per_tensor_sources["w"] + assert len(contributions) == 2 + # First contribution from trainer rank 0: [1500, 2048) in src → [0, 548) in dst + first = contributions[0] + assert first.candidate.ref.worker_id == "trainer_r0" + assert first.src_range == (1500, 2048) + assert first.dst_range == (0, 548) + # Second contribution from trainer rank 1: [0, 452) in src → [548, 1000) in dst + second = contributions[1] + assert second.candidate.ref.worker_id == "trainer_r1" + assert second.src_range == (0, 452) + assert second.dst_range == (548, 1000) + + +def test_phase4_replicate_picks_one_candidate(v2): + """REPLICATE tensor: any candidate works; planner picks the freshest one.""" + sd = sys.modules["modelexpress.shape_descriptors"] + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=0 + ) + receiver.initialize() + + blob = _registry_blob( + v2, + [{"name": "lm_head.weight", "global_shape": (1024, 4096), "placement_kind": sd.PLACEMENT_REPLICATE}], + ) + response = MagicMock() + response.instances = [ + _fake_instance("m", "s", "trainer_a"), + _fake_instance("m", "s", "trainer_b"), + ] +
metas = { + "trainer_a": _fake_meta("trainer", 0, 7, 100, registry_blob=blob), + "trainer_b": _fake_meta("trainer", 1, 7, 300, registry_blob=blob), + } + receiver._receiver._client.list_sources.return_value = response + receiver._receiver._client.get_metadata = lambda mx_source_id, worker_id: metas[worker_id] + + plan = receiver.discover_v2_sources_for_slice( + model_name="m", + target_layout=v2.TargetTPLayout(world_size=2, rank=0, shard_axis=0), + same_rank_only=False, + ) + assert plan.fully_covered + contribs = plan.per_tensor_sources["lm_head.weight"] + assert len(contribs) == 1 + # Freshest trainer (updated_at=300) preferred; either trainer is correct, but + # the picker sorts by freshness so trainer_b wins. + assert contribs[0].candidate.ref.worker_id == "trainer_b" + + +def test_phase4_coverage_gap_is_missing(v2): + """If trainers don't fully cover the receiver's slice, plan.missing fires.""" + sd = sys.modules["modelexpress.shape_descriptors"] + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=0 + ) + receiver.initialize() + # Single trainer with [0, 1024); receiver wants [0, 4096). 
+ blob = _registry_blob( + v2, + [{ + "name": "w", + "global_shape": (4096,), + "placement_kind": sd.PLACEMENT_SHARD, + "shard_axis": 0, + "local_shard_range": (0, 1024), + }], + ) + response = MagicMock() + response.instances = [_fake_instance("m", "s", "only_one")] + metas = {"only_one": _fake_meta("trainer", 0, 7, 200, registry_blob=blob)} + receiver._receiver._client.list_sources.return_value = response + receiver._receiver._client.get_metadata = lambda mx_source_id, worker_id: metas[worker_id] + + plan = receiver.discover_v2_sources_for_slice( + model_name="m", + target_layout=v2.TargetTPLayout( + world_size=1, rank=0, shard_axis=0, target_range=(0, 4096) + ), + same_rank_only=False, + ) + assert not plan.fully_covered + assert plan.missing + assert "coverage gap" in plan.missing[0] + + +def test_phase4_shard_axis_mismatch_is_missing(v2): + """Publisher shards on axis 0 but receiver wants axis 1: caller's problem.""" + sd = sys.modules["modelexpress.shape_descriptors"] + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=0 + ) + receiver.initialize() + blob = _registry_blob( + v2, + [{ + "name": "w", + "global_shape": (4096, 1024), + "placement_kind": sd.PLACEMENT_SHARD, + "shard_axis": 0, + "local_shard_range": (0, 4096), + }], + ) + response = MagicMock() + response.instances = [_fake_instance("m", "s", "trainer")] + metas = {"trainer": _fake_meta("trainer", 0, 7, 200, registry_blob=blob)} + receiver._receiver._client.list_sources.return_value = response + receiver._receiver._client.get_metadata = lambda mx_source_id, worker_id: metas[worker_id] + + plan = receiver.discover_v2_sources_for_slice( + model_name="m", + target_layout=v2.TargetTPLayout( + world_size=2, rank=0, shard_axis=1, target_range=(0, 512) + ), + same_rank_only=False, + ) + assert not plan.fully_covered + assert any("shard_axis mismatch" in m for m in plan.missing) + + +def test_phase4_receive_via_plan_stitches_two_sources(v2): + """End-to-end: 
planner + receive_via_plan correctly stitches two shards.""" + import torch + + sd = sys.modules["modelexpress.shape_descriptors"] + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=0 + ) + receiver.initialize() + blob_r0, blob_r1 = _build_two_trainers_tp4_mixed_tp8(v2) + _set_two_trainers(v2, receiver, blob_r0, blob_r1) + + # Build fake scratch tensors that match what the publishers would expose: + # trainer r0 owns [0, 2048), trainer r1 owns [2048, 4096). Each is a + # (local_extent, 1024) tensor. The registry describes it as bf16, but the + # buffers here are float32 so the in-process slice/cat math is observable. + r0_buf = torch.arange(0, 2048).repeat_interleave(1024).view(2048, 1024).float() + r1_buf = torch.arange(2048, 4096).repeat_interleave(1024).view(2048, 1024).float() + receiver._receiver._scratch_payloads = { + "trainer_r0": {"w": r0_buf}, + "trainer_r1": {"w": r1_buf}, + } + + plan = receiver.discover_v2_sources_for_slice( + model_name="m", + target_layout=v2.TargetTPLayout( + world_size=1, rank=0, shard_axis=0, target_range=(1500, 2500) + ), + same_rank_only=False, + ) + assert plan.fully_covered + + out = dict(receiver.receive_via_plan(plan)) + assert "w" in out + stitched = out["w"] + # Expect shape (1000, 1024); first 548 rows come from r0_buf[1500:2048], + # next 452 rows come from r1_buf[0:452] (which are global rows [2048,2500)).
+ assert stitched.shape == (1000, 1024) + expected = torch.cat([r0_buf[1500:2048], r1_buf[0:452]], dim=0) + assert torch.equal(stitched, expected) + + +def test_phase4_receive_via_plan_single_source_passthrough(v2): + """When one trainer covers the slice, no torch.cat happens.""" + import torch + + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=1 + ) + receiver.initialize() + blob_r0, blob_r1 = _build_two_trainers_tp4_mixed_tp8(v2) + _set_two_trainers(v2, receiver, blob_r0, blob_r1) + + r0_buf = torch.arange(0, 2048).repeat_interleave(1024).view(2048, 1024).float() + receiver._receiver._scratch_payloads = {"trainer_r0": {"w": r0_buf}} + + # Receiver TP=8 rank=1: wants rows [512, 1024) — fully inside trainer r0. + plan = receiver.discover_v2_sources_for_slice( + model_name="m", + target_layout=v2.TargetTPLayout(world_size=8, rank=1, shard_axis=0), + same_rank_only=False, + ) + assert plan.fully_covered + out = dict(receiver.receive_via_plan(plan)) + assert out["w"].shape == (512, 1024) + assert torch.equal(out["w"], r0_buf[512:1024]) + + +# ---------------------------------------------------------------------------- +# Phase 3 graduation glue — MxV2TrainingPublisher.add_tensor takes the new +# compile_target / compile_metadata kwargs and they flow into the registry. +# ---------------------------------------------------------------------------- + + +def test_phase3_add_tensor_threads_compile_target(v2): + """add_tensor's new compile_target + compile_metadata kwargs must surface + on the resulting TensorDescriptorV2 in the publisher's internal registry.""" + import torch + + sd = sys.modules["modelexpress.shape_descriptors"] + + # Stand up a publisher pointed at the stub MX client; we won't actually + # publish, just inspect the registry after add_tensor calls. 
+ pub = v2.MxV2TrainingPublisher( + agent_name="t", + device_id=0, + mx_server_url="x", + worker_rank=0, + world_layout=v2.TrainerWorldLayout(fsdp_world_size=1), + heartbeat=False, + ) + pub.initialize(model_name="m", dtype="bfloat16") + + class _FakeCudaTensor: + # Minimal stand-in: ``describe_tensor`` reads .dtype, .shape, + # optional .placements; ``add_tensor`` checks .is_cuda. + def __init__(self, shape, dtype=torch.bfloat16): + self.shape = torch.Size(shape) + self.dtype = dtype + self.is_cuda = True + + pub.add_tensor( + name="lm_head.weight", + tensor=_FakeCudaTensor([2048, 4096]), + ) + pub.add_tensor( + name="model.layers.0.mlp.gate_proj.weight", + tensor=_FakeCudaTensor([512, 2048]), + compile_target=sd.COMPILE_TARGET_CUTLASS_FP8, + compile_metadata={ + "dtype": "e4m3", + "scale_layout": "per_channel", + "scale_axis": -1, + "activation_scheme": "dynamic", + }, + ) + pub.add_tensor( + name="model.layers.0.mlp.experts.weight", + tensor=_FakeCudaTensor([24, 4096, 12288]), + is_expert=True, + owned_expert_ids=(0, 1, 2, 3), + compile_target=sd.COMPILE_TARGET_DEEPGEMM_FP8, + compile_metadata={ + "dtype": "e4m3", + "scale_layout": "blockwise", + "block_size": [128, 128], + }, + ) + + by_name = {d.name: d for d in pub._registry} + assert by_name["lm_head.weight"].compile_target == sd.COMPILE_TARGET_HF_RAW + assert by_name["lm_head.weight"].compile_metadata == {} + + gp = by_name["model.layers.0.mlp.gate_proj.weight"] + assert gp.compile_target == sd.COMPILE_TARGET_CUTLASS_FP8 + assert gp.compile_metadata == { + "dtype": "e4m3", + "scale_layout": "per_channel", + "scale_axis": -1, + "activation_scheme": "dynamic", + } + + ex = by_name["model.layers.0.mlp.experts.weight"] + assert ex.compile_target == sd.COMPILE_TARGET_DEEPGEMM_FP8 + assert ex.compile_metadata == { + "dtype": "e4m3", + "scale_layout": "blockwise", + "block_size": [128, 128], + } + assert ex.is_expert + assert set(ex.owned_expert_ids) == {0, 1, 2, 3} + + +def 
test_phase3_add_tensor_compile_target_survives_encode_decode(v2): + """Round-trip: tagged tensors → encode_registry → decode_registry preserves + compile_target + compile_metadata. Asserts the wire format is intact end-to-end.""" + import torch + + sd = sys.modules["modelexpress.shape_descriptors"] + pub = v2.MxV2TrainingPublisher( + agent_name="t", + device_id=0, + mx_server_url="x", + worker_rank=0, + world_layout=v2.TrainerWorldLayout(fsdp_world_size=1), + heartbeat=False, + ) + pub.initialize(model_name="m", dtype="bfloat16") + + class _FakeCudaTensor: + def __init__(self, shape, dtype=torch.bfloat16): + self.shape = torch.Size(shape) + self.dtype = dtype + self.is_cuda = True + + pub.add_tensor( + name="w", + tensor=_FakeCudaTensor([64, 128]), + compile_target=sd.COMPILE_TARGET_CUTLASS_FP8, + compile_metadata={"dtype": "e4m3", "scale_layout": "per_channel"}, + ) + + blob = sd.encode_registry( + pub._registry, version=42, trainer_world_layout="fsdp:1" + ) + parsed = sd.decode_registry(blob) + out = parsed["tensors"][0] + assert out.compile_target == sd.COMPILE_TARGET_CUTLASS_FP8 + assert out.compile_metadata == {"dtype": "e4m3", "scale_layout": "per_channel"} + + +def test_phase4_receive_via_plan_rejects_uncovered(v2): + """receive_via_plan refuses to run a partial plan.""" + sd = sys.modules["modelexpress.shape_descriptors"] + receiver = v2.MxV2RefitReceiver( + agent_name="t", device_id=0, mx_server_url="x", worker_rank=0 + ) + receiver.initialize() + blob = _registry_blob( + v2, + [{ + "name": "w", + "global_shape": (4096,), + "placement_kind": sd.PLACEMENT_SHARD, + "shard_axis": 0, + "local_shard_range": (0, 1024), + }], + ) + response = MagicMock() + response.instances = [_fake_instance("m", "s", "only_one")] + metas = {"only_one": _fake_meta("trainer", 0, 7, 200, registry_blob=blob)} + receiver._receiver._client.list_sources.return_value = response + receiver._receiver._client.get_metadata = lambda mx_source_id, worker_id: metas[worker_id] + + plan = 
receiver.discover_v2_sources_for_slice( + model_name="m", + target_layout=v2.TargetTPLayout( + world_size=1, rank=0, shard_axis=0, target_range=(0, 4096) + ), + same_rank_only=False, + ) + with pytest.raises(RuntimeError, match="not fully covered"): + list(receiver.receive_via_plan(plan)) diff --git a/modelexpress_client/python/tests/test_vllm_weight_transfer.py b/modelexpress_client/python/tests/test_vllm_weight_transfer.py new file mode 100644 index 00000000..0690a440 --- /dev/null +++ b/modelexpress_client/python/tests/test_vllm_weight_transfer.py @@ -0,0 +1,564 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +"""Unit tests for the vLLM WeightTransferEngine adapter. + +These tests exercise the dispatch logic in ``MxWeightTransferEngine`` +without requiring vLLM, NIXL, or a live MX server. They follow the +same direct-load + stub pattern as ``test_v2_source_picker.py`` so the +suite runs on a plain CPU box. +""" + +from __future__ import annotations + +import importlib.util +import os +import sys +import types +from dataclasses import dataclass +from pathlib import Path +from unittest.mock import MagicMock + +import pytest + + +_HERE = Path(__file__).resolve().parent +_PKG_ROOT = _HERE.parent / "modelexpress" + + +def _load(modname: str, path: Path): + spec = importlib.util.spec_from_file_location(modname, path) + mod = importlib.util.module_from_spec(spec) + sys.modules[modname] = mod + spec.loader.exec_module(mod) + return mod + + +@pytest.fixture(scope="module") +def vllm_wt(): + """Load the vllm_weight_transfer module against the stubs we set up + for ``test_v2_source_picker.py``. We re-create those stubs here so + this test file is self-contained (no fixture order coupling). + + Crucially we set ``MX_WEIGHT_TRANSFER_AUTOREGISTER=0`` so the module + doesn't try to register with a (non-existent) vLLM at import time. 
+ """ + os.environ["MX_WEIGHT_TRANSFER_AUTOREGISTER"] = "0" + + # Build the same stub modules ``test_v2_source_picker.py`` uses. + pkg = types.ModuleType("modelexpress") + pkg.__path__ = [str(_PKG_ROOT)] # type: ignore[attr-defined] + sys.modules["modelexpress"] = pkg + + p2p_pb2 = types.ModuleType("modelexpress.p2p_pb2") + p2p_pb2.SOURCE_STATUS_READY = 2 + p2p_pb2.SOURCE_STATUS_INITIALIZING = 1 + p2p_pb2.SOURCE_STATUS_STALE = 3 + p2p_pb2.MX_SOURCE_TYPE_WEIGHTS = 0 + p2p_pb2.BACKEND_FRAMEWORK_UNKNOWN = 0 + sys.modules["modelexpress.p2p_pb2"] = p2p_pb2 + + @dataclass + class _SourceIdentity: + model_name: str = "" + mx_source_type: int = 0 + backend_framework: int = 0 + tensor_parallel_size: int = 0 + pipeline_parallel_size: int = 0 + expert_parallel_size: int = 0 + dtype: str = "" + quantization: str = "" + + def __post_init__(self): + self.extra_parameters = {} + + @dataclass + class _WorkerMetadata: + worker_rank: int = 0 + nixl_metadata: bytes = b"" + tensors: list = None + status: int = 0 + agent_name: str = "" + + def __post_init__(self): + if self.tensors is None: + self.tensors = [] + + @dataclass + class _TensorDescriptor: + name: str = "" + addr: int = 0 + size: int = 0 + device_id: int = 0 + dtype: str = "" + + p2p_pb2.SourceIdentity = _SourceIdentity + p2p_pb2.WorkerMetadata = _WorkerMetadata + p2p_pb2.TensorDescriptor = _TensorDescriptor + + # Heartbeat stub + hb = types.ModuleType("modelexpress.heartbeat") + + class _HBStub: + def __init__(self, *a, **kw): + self.started = False + + def start(self): + self.started = True + + def stop(self): + self.started = False + + hb.HeartbeatThread = _HBStub + sys.modules["modelexpress.heartbeat"] = hb + + # MxRefitReceiver / MxTrainingPublisher stubs + refit_mod = types.ModuleType("modelexpress.refit_receiver") + + @dataclass + class _SourceRef: + mx_source_id: str = "" + worker_id: str = "" + model_name: str = "" + worker_rank: int = 0 + training_step: int = 0 + + @dataclass + class _TransferStats: + 
bytes_received: int = 0 + bytes_skipped: int = 0 + tensors_received: int = 0 + elapsed_seconds: float = 0.0 + bandwidth_gbps: float = 0.0 + discovery_seconds: float = 0.0 + path: str = "" + training_step: int = 0 + source_worker_rank: int | None = None + + class _RefitStub: + def __init__(self, *a, **kw): + self._client = MagicMock() + self._nixl = MagicMock() + self._agent_name = kw.get("agent_name", "stub") + self._worker_id = "stub-worker" + self.last_stats = _TransferStats() + self.history: list = [] + + def initialize(self, model_tensors=None): + pass + + def receive_weights(self, ref, timeout_seconds=300.0): + return iter([]) + + def receive_weights_scratch(self, ref, timeout_seconds=300.0, tensor_shapes=None): + return iter([]) + + refit_mod.MxRefitReceiver = _RefitStub + refit_mod.SourceRef = _SourceRef + refit_mod.TransferStats = _TransferStats + sys.modules["modelexpress.refit_receiver"] = refit_mod + + pub_mod = types.ModuleType("modelexpress.training_publisher") + + class _PubStub: + def __init__(self, *a, **kw): + self._client = None + self._nixl = None + self.mx_source_id = "abcd1234" + self.worker_id = "stub-pub-worker" + + def initialize(self, **kw): + pass + + def publish_weights(self, named_tensors, step, worker_rank): + return self.mx_source_id + + def mark_ready(self, worker_rank=0): + return True + + def shutdown(self): + pass + + def _build_identity(self, step): + return p2p_pb2.SourceIdentity() + + pub_mod.MxTrainingPublisher = _PubStub + sys.modules["modelexpress.training_publisher"] = pub_mod + + types_mod = types.ModuleType("modelexpress.types") + + @dataclass + class _TD: + name: str = "" + addr: int = 0 + size: int = 0 + device_id: int = 0 + dtype: str = "" + + types_mod.TensorDescriptor = _TD + sys.modules["modelexpress.types"] = types_mod + + # Now exec the real modules against these stubs. 
+ sd = _load("modelexpress.shape_descriptors", _PKG_ROOT / "shape_descriptors.py") + pkg.shape_descriptors = sd # type: ignore[attr-defined] + v2 = _load("modelexpress.nemo_rl_v2", _PKG_ROOT / "nemo_rl_v2.py") + pkg.nemo_rl_v2 = v2 # type: ignore[attr-defined] + wt = _load( + "modelexpress.vllm_weight_transfer", _PKG_ROOT / "vllm_weight_transfer.py" + ) + return wt, v2, sd + + +def test_engine_construction_without_init_info(vllm_wt): + """Engine can be constructed with no init_info; init_transfer_engine + is called separately.""" + wt, _, _ = vllm_wt + engine = wt.MxWeightTransferEngine() + assert engine._receiver is None + assert engine._init_info is None + + +def test_engine_construction_with_init_info(vllm_wt): + """Engine can be constructed with init_info; receiver is built eagerly.""" + wt, _, _ = vllm_wt + init = wt.MxInitInfo( + mx_server_url="fake:8001", + model_name="m", + worker_rank=0, + agent_name="ag", + device_id=0, + ) + engine = wt.MxWeightTransferEngine(init_info=init) + assert engine._receiver is not None + assert engine._init_info is init + + +def test_receive_weights_without_init_raises(vllm_wt): + """Calling receive_weights before init_transfer_engine should error.""" + wt, _, _ = vllm_wt + engine = wt.MxWeightTransferEngine() + update = wt.MxUpdateInfo(version=1) + with pytest.raises(RuntimeError, match="init_transfer_engine"): + engine.receive_weights(update, load_weights=lambda batch: None) + + +def test_receive_weights_matched_tp_path(vllm_wt, monkeypatch): + """Matched-TP path: target_tp_layout=None → discover_v2_sources + + pick_best_source + receive_from_scratch.
Verify load_weights is called with + yielded (name, tensor) pairs.""" + wt, v2, sd = vllm_wt + engine = wt.MxWeightTransferEngine( + init_info=wt.MxInitInfo( + mx_server_url="fake:8001", + model_name="m", + worker_rank=0, + agent_name="ag", + publish_self_as_replica=False, # disable to keep this test tight + ) + ) + + fake_candidate = MagicMock(name="V2SourceCandidate") + fake_candidate.ref = MagicMock() + monkeypatch.setattr( + engine._receiver, "discover_v2_sources", lambda **kw: [fake_candidate] + ) + monkeypatch.setattr( + engine._receiver, "pick_best_source", lambda *a, **kw: fake_candidate + ) + yielded = [("w1", "T1"), ("w2", "T2")] + monkeypatch.setattr( + engine._receiver, "receive_from_scratch", lambda c, **kw: iter(yielded) + ) + + received = [] + engine.receive_weights( + wt.MxUpdateInfo(version=42), + load_weights=lambda batch: received.extend(batch), + ) + assert received == yielded + + +def test_receive_weights_mixed_tp_phase4_path(vllm_wt, monkeypatch): + """Mixed-TP path: target_tp_layout set → discover_v2_sources_for_slice + + receive_via_plan. 
Verify the Phase-4 plan is built and stitching + happens.""" + wt, v2, sd = vllm_wt + engine = wt.MxWeightTransferEngine( + init_info=wt.MxInitInfo( + mx_server_url="fake:8001", + model_name="m", + worker_rank=0, + agent_name="ag", + publish_self_as_replica=False, + ) + ) + + fake_plan = MagicMock(name="SliceCoveragePlan") + fake_plan.fully_covered = True + fake_plan.missing = [] + monkeypatch.setattr( + engine._receiver, "discover_v2_sources_for_slice", lambda **kw: fake_plan + ) + yielded = [("w", "STITCHED_TENSOR")] + monkeypatch.setattr( + engine._receiver, + "receive_via_plan", + lambda plan, **kw: iter(yielded) if plan is fake_plan else iter([]), + ) + + received = [] + update = wt.MxUpdateInfo( + version=99, + target_tp_layout=v2.TargetTPLayout(world_size=8, rank=3, shard_axis=0), + ) + engine.receive_weights(update, load_weights=lambda batch: received.extend(batch)) + assert received == yielded + + +def test_receive_weights_phase4_uncovered_plan_raises(vllm_wt, monkeypatch): + """Phase-4 path: a partial slice plan (missing entries) raises before + any RDMA cycles are spent.""" + wt, v2, _ = vllm_wt + engine = wt.MxWeightTransferEngine( + init_info=wt.MxInitInfo( + mx_server_url="x", + model_name="m", + worker_rank=0, + agent_name="ag", + publish_self_as_replica=False, + ) + ) + bad_plan = MagicMock(name="SliceCoveragePlan") + bad_plan.fully_covered = False + bad_plan.missing = ["w: coverage gap"] + monkeypatch.setattr( + engine._receiver, "discover_v2_sources_for_slice", lambda **kw: bad_plan + ) + + with pytest.raises(RuntimeError, match="no covering source set"): + engine.receive_weights( + wt.MxUpdateInfo( + version=1, + target_tp_layout=v2.TargetTPLayout(world_size=2, rank=0), + ), + load_weights=lambda batch: None, + ) + + +def test_receive_weights_matched_no_candidates_raises(vllm_wt, monkeypatch): + """Matched path: when no source passes the filter, raise BEFORE + NIXL receive. 
This is the Phase-3b safety net for the + compile_target_filter case.""" + wt, _, _ = vllm_wt + engine = wt.MxWeightTransferEngine( + init_info=wt.MxInitInfo( + mx_server_url="x", + model_name="m", + worker_rank=0, + agent_name="ag", + publish_self_as_replica=False, + ) + ) + monkeypatch.setattr(engine._receiver, "discover_v2_sources", lambda **kw: []) + monkeypatch.setattr(engine._receiver, "pick_best_source", lambda *a, **kw: None) + + with pytest.raises(RuntimeError, match="no source matches filters"): + engine.receive_weights( + wt.MxUpdateInfo( + version=1, + compile_target_filter={"cutlass_fp8"}, + ), + load_weights=lambda batch: None, + ) + + +def test_receive_weights_compile_target_filter_threaded_through(vllm_wt, monkeypatch): + """The MxUpdateInfo's compile_target_filter and required_compile_metadata + must reach the receiver's discover_v2_sources call unchanged.""" + wt, _, sd = vllm_wt + engine = wt.MxWeightTransferEngine( + init_info=wt.MxInitInfo( + mx_server_url="x", + model_name="m", + worker_rank=0, + agent_name="ag", + publish_self_as_replica=False, + ) + ) + captured: dict[str, object] = {} + + def fake_discover(**kw): + captured.update(kw) + return [] + + monkeypatch.setattr(engine._receiver, "discover_v2_sources", fake_discover) + monkeypatch.setattr(engine._receiver, "pick_best_source", lambda *a, **kw: None) + + with pytest.raises(RuntimeError): + engine.receive_weights( + wt.MxUpdateInfo( + version=7, + compile_target_filter={sd.COMPILE_TARGET_CUTLASS_FP8}, + required_compile_metadata={"block_size": 128}, + same_rank_only=False, + dedup_freshest_per_rank=False, + ), + load_weights=lambda batch: None, + ) + assert captured["model_name"] == "m" + assert captured["min_version"] == 7 + assert captured["compile_target_filter"] == {sd.COMPILE_TARGET_CUTLASS_FP8} + assert captured["required_compile_metadata"] == {"block_size": 128} + assert captured["same_rank_only"] is False + + +def 
test_receive_weights_publishes_self_as_replica_when_enabled(vllm_wt, monkeypatch): + """With publish_self_as_replica=True, after a successful receive + the engine triggers tree fan-out.""" + wt, _, _ = vllm_wt + engine = wt.MxWeightTransferEngine( + init_info=wt.MxInitInfo( + mx_server_url="x", + model_name="m", + worker_rank=0, + agent_name="ag", + publish_self_as_replica=True, + ) + ) + cand = MagicMock() + monkeypatch.setattr(engine._receiver, "discover_v2_sources", lambda **kw: [cand]) + monkeypatch.setattr(engine._receiver, "pick_best_source", lambda *a, **kw: cand) + monkeypatch.setattr(engine._receiver, "receive_from_scratch", lambda *a, **kw: iter([])) + publish_calls = [] + + def fake_publish(*, version, model_name): + publish_calls.append((version, model_name)) + return "replica-source-id" + + monkeypatch.setattr(engine._receiver, "publish_self_as_source", fake_publish) + engine.receive_weights( + wt.MxUpdateInfo(version=11), load_weights=lambda batch: None + ) + assert publish_calls == [(11, "m")] + + +def test_receive_weights_publish_self_failure_is_swallowed(vllm_wt, monkeypatch): + """publish_self_as_replica failure must NOT propagate — it's a + best-effort optimization, not correctness.""" + wt, _, _ = vllm_wt + engine = wt.MxWeightTransferEngine( + init_info=wt.MxInitInfo( + mx_server_url="x", + model_name="m", + worker_rank=0, + agent_name="ag", + publish_self_as_replica=True, + ) + ) + cand = MagicMock() + monkeypatch.setattr(engine._receiver, "discover_v2_sources", lambda **kw: [cand]) + monkeypatch.setattr(engine._receiver, "pick_best_source", lambda *a, **kw: cand) + monkeypatch.setattr(engine._receiver, "receive_from_scratch", lambda *a, **kw: iter([])) + + def broken_publish(*, version, model_name): + raise RuntimeError("MX server unreachable") + + monkeypatch.setattr(engine._receiver, "publish_self_as_source", broken_publish) + + # Should NOT raise. 
+ engine.receive_weights( + wt.MxUpdateInfo(version=11), load_weights=lambda batch: None + ) + + +def test_trainer_send_weights_threads_compile_target(vllm_wt, monkeypatch): + """Trainer-side classmethod: each tensor in the iterator gets + add_tensor'd with the compile_target + compile_metadata from + MxTrainerSendArgs, and finally publish(version) is called.""" + wt, v2, sd = vllm_wt + + added: list[dict] = [] + published_with_version: list[int] = [] + + class _RecordingPublisher: + def add_tensor(self, **kw): + added.append(kw) + + def publish(self, *, version): + published_with_version.append(version) + return "trainer-source-id" + + pub = _RecordingPublisher() + args = wt.MxTrainerSendArgs( + publisher=pub, + version=42, + compile_target=sd.COMPILE_TARGET_CUTLASS_FP8, + compile_metadata={"block_size": 128, "scale_layout": "per_channel"}, + expert_axis_map={"expert.w": 0}, + owned_expert_ids={"expert.w": (0, 1, 2, 3)}, + ) + + iterator = iter([ + ("w1", "FAKE_TENSOR_1"), + ("expert.w", "FAKE_TENSOR_2"), + ]) + out = wt.MxWeightTransferEngine.trainer_send_weights(iterator, args) + assert out == "trainer-source-id" + assert len(added) == 2 + assert added[0]["name"] == "w1" + assert added[0]["compile_target"] == sd.COMPILE_TARGET_CUTLASS_FP8 + assert added[0]["compile_metadata"] == { + "block_size": 128, + "scale_layout": "per_channel", + } + assert added[0]["is_expert"] is False + assert added[1]["name"] == "expert.w" + assert added[1]["is_expert"] is True + assert added[1]["expert_axis"] == 0 + assert added[1]["owned_expert_ids"] == (0, 1, 2, 3) + assert published_with_version == [42] + + +def test_metrics_surface_exposed(vllm_wt): + """The engine exposes last_transfer_stats / transfer_history / + last_discovery_seconds for benchmark consumers.""" + wt, _, _ = vllm_wt + engine = wt.MxWeightTransferEngine() + # Pre-init: graceful Nones / empties + assert engine.last_transfer_stats is None + assert engine.transfer_history == [] + assert 
engine.last_discovery_seconds == 0.0 + + engine.init_transfer_engine( + wt.MxInitInfo( + mx_server_url="x", + model_name="m", + worker_rank=0, + agent_name="ag", + ) + ) + # Post-init: surfaces are wired through the receiver + assert engine.last_transfer_stats is not None # the empty TransferStats + assert engine.last_transfer_stats.bytes_received == 0 + assert engine.transfer_history == [] + assert engine.last_discovery_seconds == 0.0 + + +def test_engine_is_registered_when_vllm_unavailable(vllm_wt): + """In environments without vLLM, _AUTOREGISTERED is False; the + engine is still usable directly. (We force this case via the env + var in the fixture.)""" + wt, _, _ = vllm_wt + # The class is exported regardless of registration outcome. + assert wt.MxWeightTransferEngine is not None + # Auto-registration was disabled via env var in the fixture; the + # engine should be usable but not necessarily registered. + assert isinstance(wt._AUTOREGISTERED, bool) + + +def test_engine_exposes_init_info_cls_and_update_info_cls(vllm_wt): + """The vLLM factory contract: each engine class declares its info + types as class attributes.""" + wt, _, _ = vllm_wt + assert wt.MxWeightTransferEngine.init_info_cls is wt.MxInitInfo + assert wt.MxWeightTransferEngine.update_info_cls is wt.MxUpdateInfo diff --git a/modelexpress_common/proto/p2p.proto b/modelexpress_common/proto/p2p.proto index a78f427b..c3753391 100644 --- a/modelexpress_common/proto/p2p.proto +++ b/modelexpress_common/proto/p2p.proto @@ -277,6 +277,13 @@ message GetMetadataResponse { // Echoed worker_id string worker_id = 4; + + // Source identity (mirrors the SourceIdentity that produced mx_source_id). + // Required by v2 (NemoRL) clients that store framework metadata + // (training_step, role, shape registry, ...) in extra_parameters. + // Pre-v2 clients ignore this field; populating it on existing servers is + // backward-compatible. 
+ SourceIdentity identity = 5; } // ============================================================================ diff --git a/modelexpress_server/src/p2p/backend.rs b/modelexpress_server/src/p2p/backend.rs index 456ea69e..a843b15d 100644 --- a/modelexpress_server/src/p2p/backend.rs +++ b/modelexpress_server/src/p2p/backend.rs @@ -30,6 +30,15 @@ pub struct ModelMetadataRecord { pub model_name: String, pub workers: Vec, pub published_at: i64, + /// Full SourceIdentity that produced ``source_id``. + /// + /// Older records (written before this field was added) leave it ``None``. + /// New v2 RL clients (NemoRL `update_weights_via_mx`) read framework-level + /// state from `identity.extra_parameters` (training_step, role, shape + /// registry, dirty experts). Backends are responsible for round-tripping + /// the entire SourceIdentity message; pre-existing records continue to + /// work because the field is optional. + pub identity: Option<SourceIdentity>, } /// Lightweight reference to a source worker (no tensor metadata). diff --git a/modelexpress_server/src/p2p/backend/kubernetes.rs b/modelexpress_server/src/p2p/backend/kubernetes.rs index db8bf969..d149614d 100644 --- a/modelexpress_server/src/p2p/backend/kubernetes.rs +++ b/modelexpress_server/src/p2p/backend/kubernetes.rs @@ -442,6 +442,11 @@ impl MetadataBackend for KubernetesBackend { model_name: cr.spec.model_name.clone(), workers, published_at, + // K8s CRD backend doesn't yet round-trip the full SourceIdentity. + // v2 RL clients fall back to the sidecar transport (synthetic + // TensorDescriptor named __mx_v2_meta__) until the CRD schema is + // extended; see modelexpress_client/python/modelexpress/nemo_rl_v2.py. 
+ identity: None, })) } diff --git a/modelexpress_server/src/p2p/backend/redis.rs b/modelexpress_server/src/p2p/backend/redis.rs index bd838e1c..13bcdb17 100644 --- a/modelexpress_server/src/p2p/backend/redis.rs +++ b/modelexpress_server/src/p2p/backend/redis.rs @@ -54,6 +54,12 @@ struct SourceAttributesJson { pub dtype: String, #[serde(default)] pub quantization: String, + /// Framework-specific config from `SourceIdentity.extra_parameters`. + /// Required by v2 RL clients (NemoRL `update_weights_via_mx`) that stash + /// version, role, shape registry, etc. here. Older records (pre-v2) + /// deserialize to an empty map via `#[serde(default)]`. + #[serde(default)] + pub extra_parameters: std::collections::HashMap<String, String>, } impl From<&SourceIdentity> for SourceAttributesJson { @@ -68,6 +74,26 @@ impl From<&SourceIdentity> for SourceAttributesJson { expert_parallel_size: id.expert_parallel_size, dtype: id.dtype.clone(), quantization: id.quantization.clone(), + extra_parameters: id.extra_parameters.clone(), + } + } +} + +impl SourceAttributesJson { + /// Round-trip back to a SourceIdentity proto. Used by GetMetadata to + /// populate ``GetMetadataResponse.identity``. + fn to_source_identity(&self) -> SourceIdentity { + SourceIdentity { + mx_version: self.mx_version.clone(), + mx_source_type: self.mx_source_type, + model_name: self.model_name.clone(), + backend_framework: self.backend_framework, + tensor_parallel_size: self.tensor_parallel_size, + pipeline_parallel_size: self.pipeline_parallel_size, + expert_parallel_size: self.expert_parallel_size, + dtype: self.dtype.clone(), + quantization: self.quantization.clone(), + extra_parameters: self.extra_parameters.clone(), } } } @@ -382,13 +408,16 @@ impl MetadataBackend for RedisBackend { return Ok(None); } - // Fetch model_name from the source index key's __attributes__ field. + // Fetch the full SourceAttributesJson from the source index key's + // __attributes__ field. 
This carries model_name, framework knobs, and + // (for v2 RL clients) extra_parameters. let source_key = format!("{}{}", keys::SOURCE_PREFIX, source_id); let attr_json: Option<String> = conn.hget(&source_key, keys::ATTRIBUTES_FIELD).await?; - let model_name = attr_json + let attrs = attr_json .and_then(|v| serde_json::from_str::<SourceAttributesJson>(&v).ok()) - .map(|a| a.model_name) .unwrap_or_default(); + let model_name = attrs.model_name.clone(); + let identity = Some(attrs.to_source_identity()); let mut workers: Vec = Vec::with_capacity(fields.len()); for value in fields.values() { @@ -410,6 +439,7 @@ impl MetadataBackend for RedisBackend { model_name, workers, published_at: 0, + identity, })) } diff --git a/modelexpress_server/src/p2p/service.rs b/modelexpress_server/src/p2p/service.rs index 8b929278..e24bddd6 100644 --- a/modelexpress_server/src/p2p/service.rs +++ b/modelexpress_server/src/p2p/service.rs @@ -179,6 +179,7 @@ impl P2pService for P2pServiceImpl { worker: None, mx_source_id: String::new(), worker_id: String::new(), + identity: None, })); } @@ -189,20 +190,23 @@ impl P2pService for P2pServiceImpl { { Ok(Some(record)) => { // Each worker_id maps to exactly one worker record; take the first. 
+ let identity = record.identity.clone(); let worker = record.workers.into_iter().next().map(WorkerMetadata::from); let found = worker.is_some(); info!( - "GetMetadata '{}' (source_id={}, worker_id={}): {} tensors", + "GetMetadata '{}' (source_id={}, worker_id={}): {} tensors, identity_present={}", record.model_name, req.mx_source_id, req.worker_id, worker.as_ref().map_or(0, |w| w.tensors.len()), + identity.is_some(), ); Ok(Response::new(GetMetadataResponse { found, worker, mx_source_id: req.mx_source_id, worker_id: req.worker_id, + identity, })) } Ok(None) => { @@ -215,6 +219,7 @@ impl P2pService for P2pServiceImpl { worker: None, mx_source_id: req.mx_source_id, worker_id: req.worker_id, + identity: None, })) } Err(e) => { @@ -224,6 +229,7 @@ impl P2pService for P2pServiceImpl { worker: None, mx_source_id: String::new(), worker_id: String::new(), + identity: None, })) } } @@ -461,6 +467,7 @@ mod tests { worker_grpc_endpoint: String::new(), }], published_at: 1234567890, + identity: None, })) }); @@ -744,6 +751,7 @@ mod tests { model_name: "my-model".to_string(), workers: vec![], published_at: 0, + identity: None, })) }); diff --git a/modelexpress_server/src/p2p/state.rs b/modelexpress_server/src/p2p/state.rs index a0f23c51..51fc0a51 100644 --- a/modelexpress_server/src/p2p/state.rs +++ b/modelexpress_server/src/p2p/state.rs @@ -375,6 +375,7 @@ mod tests { }, ], published_at: 1234567890, + identity: None, }; assert_eq!(record.model_name, "meta-llama/Llama-3.1-70B");