From 501ebccb393076c3d0eb7313a30ff7179bcb1595 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Fri, 3 Apr 2026 10:35:36 -0700 Subject: [PATCH 01/25] docs: add RL weight update integration slide deck Presentation covering ModelExpress integration into RL post-training weight sync (refit) for NeMo RL, verl, and PRIME-RL. Includes HTML slideshow and standalone SVG diagrams for architecture, transfer flow, component stack, and framework comparison. Made-with: Cursor Signed-off-by: Kavin Krishnan --- docs/slides/diagram-architecture.svg | 99 +++ docs/slides/diagram-component-stack.svg | 148 ++++ docs/slides/diagram-framework-comparison.svg | 112 +++ docs/slides/diagram-rl-loop-bottleneck.svg | 82 ++ docs/slides/diagram-transfer-flow.svg | 88 ++ docs/slides/mx-rl-integration-slides.html | 793 +++++++++++++++++++ 6 files changed, 1322 insertions(+) create mode 100644 docs/slides/diagram-architecture.svg create mode 100644 docs/slides/diagram-component-stack.svg create mode 100644 docs/slides/diagram-framework-comparison.svg create mode 100644 docs/slides/diagram-rl-loop-bottleneck.svg create mode 100644 docs/slides/diagram-transfer-flow.svg create mode 100644 docs/slides/mx-rl-integration-slides.html diff --git a/docs/slides/diagram-architecture.svg b/docs/slides/diagram-architecture.svg new file mode 100644 index 00000000..9e77b4a4 --- /dev/null +++ b/docs/slides/diagram-architecture.svg @@ -0,0 +1,99 @@ + + + + + + + + + + MODELEXPRESS RL REFIT ARCHITECTURE + + + + TRAINING WORKERS + + + FSDP2 / Megatron-LM + + + WeightExtractor + + Gather params per bucket + + + MxTrainingPublisher + + Register tensors with NIXL + Publish metadata (gRPC) + + + NIXL Agent (per GPU) + UCX/RDMA + registration + + + + MX SERVER + + + gRPC P2P Coordination + + + Redis / K8s CRD + + + Version Tracking + + + Refit Coordination + + + + INFERENCE WORKERS + + + vLLM / SGLang + + + MxRefitReceiver + + Poll for new versions + Receive RDMA writes + + + NIXL Agent (per GPU) + + + Apply weights in-place + + + Resume rollout inference + + + + + gRPC + + + + gRPC + + + + + + + RDMA WRITE + + + + GPU VRAM ’ GPU VRAM · Bypasses CPU, disk, and collective overhead + + + Automatic fallback: RDMA ’ CUDA IPC ’ TCP + + diff --git a/docs/slides/diagram-component-stack.svg b/docs/slides/diagram-component-stack.svg new file mode 100644 index 00000000..5f23f3a0 --- /dev/null +++ b/docs/slides/diagram-component-stack.svg @@ -0,0 +1,148 @@ + + + + + + + COMPONENT ARCHITECTURE  NEW & EXISTING + + + + TRAINING WORKERS + + + RL FRAMEWORK + NeMo RL / verl / PRIME-RL + + + TRAINING BACKEND + FSDP2 / Megatron-LM + + + + + + WeightExtractor + Gather params per bucket (FSDP2 / Megatron) + NEW + + + + + + MxTrainingPublisher + Register tensors with NIXL + Publish metadata + step version (gRPC) + NEW + + + + + + ResharderPlugin + Gather-then-shard / Direct-match / Auto + NEW + + + + + + NIXL Agent (per GPU) + UCX backend · RDMA WRITE capability + EXISTING + + + + MODELEXPRESS SERVER (RUST) + + + gRPC P2P Service + EXISTING + + + Redis / K8s CRD Metadata + EXISTING + + + Refit Version Tracking + NEW + + + Bucket Coordination + NEW + + + Heartbeat / Stale Reaper + EXISTING + + + p2p.proto (extended) + MODIFIED + + + + INFERENCE WORKERS + + + INFERENCE ENGINE + vLLM / SGLang + + + MxRefitReceiver + Poll for new weight versions + Coordinate with MX Server + NEW + + + NIXL Agent (per GPU) + UCX · Pre-registered receive buffers + EXISTING + + + Apply Weights In-Place + FP8 quantization on receiver side + + + SyncPolicyController + Sync · One-step-off · Fully async + NEW + + + Resume Rollout Generation + + + + + + gRPC + + + + + gRPC + + + + + + + RDMA WRITE (per bucket) + + + + LEGEND + + New component + + Modified + + Existing (NIXL, gRPC, Redis) + + Metadata (gRPC) + + Data plane + diff --git a/docs/slides/diagram-framework-comparison.svg b/docs/slides/diagram-framework-comparison.svg new file mode 100644 index 00000000..4f3e256c --- /dev/null +++ b/docs/slides/diagram-framework-comparison.svg @@ -0,0 +1,112 @@ + + + + + + + + + + FRAMEWORK INTEGRATION COMPARISON + + + + NeMo RL + + + verl + + + PRIME-RL + + + + + Training Backend + + + DTensor / Megatron + + + FSDP / FSDP2 / Megatron + + + FSDP2 (EP/CP) + + + + Inference Backend + + + vLLM, SGLang + + + vLLM, SGLang, HF + + + vLLM + + + + Current Sync + + + ZMQ IPC / HTTP / NCCL + + + NCCL / CheckpointEngine + + + Filesystem + HTTP + + + + MX Insertion + + + refit_policy_ + generation() + + + CheckpointEngine + ABC (v0.7) + + + Orchestrator + relay_weights() + + + + Primary Benefit + + + RDMA replaces ZMQ/NCCL + + + Multi-node w/o filesystem + + + Eliminates disk I/O + + + + Orchestration + + + Ray actors + + + Ray (driver + workers) + + + Lightweight orchestrator + + + + Shared Abstraction: WeightSyncBackend + Framework-agnostic interface decoupling transport mechanism from sync policy + diff --git a/docs/slides/diagram-rl-loop-bottleneck.svg b/docs/slides/diagram-rl-loop-bottleneck.svg new file mode 100644 index 00000000..272801dd --- /dev/null +++ b/docs/slides/diagram-rl-loop-bottleneck.svg @@ -0,0 +1,82 @@ + + + + + + + + + + + + + + + + + + + ON-POLICY RL TRAINING LOOP + + + + ROLLOUT + Inference Engine + vLLM / SGLang + Generate trajectories + + + + + + + REWARD + Compute Rewards + Rule-based / Model RM + + + + + + + TRAINING + Policy Gradient Update + FSDP2 / Megatron-LM + GRPO / PPO / DAPO + + + + + + + WEIGHT SYNC (REFIT) + BOTTLENECK + + + ILLUSTRATIVE WALL-CLOCK BREAKDOWN + + + Rollout (40%) + + + Rew + + + Train (20%) + + + REFIT (30%) + + + + ▲ Up to 30-40% for 70B+ models + + + + Filesystem ~20s+ | NCCL ~10s | ZMQ IPC ~3-5s | MX RDMA ~5s (target for 70B) + + diff --git a/docs/slides/diagram-transfer-flow.svg b/docs/slides/diagram-transfer-flow.svg new file mode 100644 index 00000000..aead0cfe --- /dev/null +++ b/docs/slides/diagram-transfer-flow.svg @@ -0,0 +1,88 @@ + + + + + + + + + + RL REFIT TRANSFER FLOW (ONE TRAINING STEP) + + + + Training (rank k) + + + MX Server + + + Inference (rank k) + + + + + + + + + optimizer.step() + + + + WeightExtractor.extract() + + + + + PublishMetadata(step=N) + + + + + UpdateStatus(READY) + + + + serving rollout... + + + + + poll_for_update() + + + + + GetMetadata(rank k) + + + + add_remote_agent(NIXL blob) + + + + + + RDMA WRITE (bucket_0) + + + + apply_weights(bucket_0) + + + + repeat for bucket_1 ... bucket_N (bounded GPU memory) + + + + resume_inference() + + + + next training step + diff --git a/docs/slides/mx-rl-integration-slides.html b/docs/slides/mx-rl-integration-slides.html new file mode 100644 index 00000000..f828b583 --- /dev/null +++ b/docs/slides/mx-rl-integration-slides.html @@ -0,0 +1,793 @@ + + + + + +ModelExpress RL Weight Update Integration + + + + + + +
+ + +
+
+ NVIDIA NIXL + ModelExpress + RL Post-Training +
+

ModelExpress for
RL Weight Updates

+

+ Extending GPU-to-GPU RDMA transfers from inference scaling to the training→inference refit boundary in reinforcement learning post-training. +

+

+ Target frameworks: NeMo RL · verl · PRIME-RL +

+
April 2026 — Integration Design Overview
+
+ + +
+
The Problem
+

The Weight Sync Bottleneck

+ +

+ On-policy RL (GRPO, PPO, DAPO) alternates between rollout generation on inference GPUs and gradient updates on training GPUs. After every training step, updated weights must reach inference before the next rollout — this refit phase stalls both sides. +

+ +
+
Wall-clock time breakdown (illustrative)
+
+
Rollout
+
Rew
+
Train
+
REFIT
+
+
▲ Up to 30–40% of wall-clock for large models
+
+ +
Current refit latency (70B-class model, multi-node)
+
+
+
Filesystem (PRIME-RL)
+
~20s+
+
+
+
NCCL Broadcast (NeMo RL)
+
~10s
+
+
+
ZMQ IPC (NeMo RL, co-loc)
+
~3-5s
+
+
+
MX RDMA P2P (target)
+
~5s
+
+
+
+ + +
+
The Solution
+

ModelExpress for Training→Inference Refit

+ +

+ Extend MX from inference-to-inference P2P to the training→inference boundary. Training workers register updated weights with NIXL, publish metadata to the MX Server, and RDMA-WRITE directly into inference GPU memory — bypassing CPU, disk, and collective overheads. +

+ +
+

+ [ INSERT DIAGRAM: diagram-architecture.svg ] +

+
+
+
+
Training Workers
+
FSDP2 / Megatron
+
WeightExtractor
+
MxTrainingPublisher
+
NIXL Agent per GPU
+
+
+
+
MX Server
+
gRPC Metadata Coord
+
Redis / K8s CRD
+
Version Tracking
+
+
+
+
Inference Workers
+
vLLM / SGLang
+
MxRefitReceiver
+
NIXL Agent per GPU
+
+
+
+ RDMA WRITE — GPU VRAM → GPU VRAM — bypasses CPU & disk +
+
+
+ +
+
+
~5s
+
Llama-3.3-70B (140 GB)
+
+
+
~15s
+
DeepSeek-V3 MoE (681 GB)
+
+
+
0
+
CPU staging required
+
+
+
Auto
+
IPC → RDMA → TCP fallback
+
+
+
+ + +
+
Architecture
+

Component Deep-Dive

+ +
+ + + + TRAINING WORKERS + + + + RL FRAMEWORK + NeMo RL / verl / PRIME-RL + + + + TRAINING BACKEND + FSDP2 / Megatron-LM + + + + + + + + WeightExtractor + Gather params • Bucket iteration + NEW + + + + + + + + MxTrainingPublisher + Register tensors with NIXL + Publish metadata (gRPC) • Version tag + NEW + + + + + + + + ResharderPlugin + Gather-then-shard • Direct-match + NEW + + + + + + + + NIXL Agent + UCX backend • Per-GPU • RDMA WRITE + EXISTING + + + + MODELEXPRESS SERVER + + + gRPC P2P Service + EXISTING + + + Redis / K8s CRD Backend + EXISTING + + + Refit Version Tracking + NEW + + + Bucket Coordination + NEW + + + Heartbeat / Reaper + EXISTING + + + p2p.proto (extended) + MODIFIED + + + + INFERENCE WORKERS + + + + INFERENCE ENGINE + vLLM / SGLang + + + + MxRefitReceiver + Poll for new weight versions + Coordinate with MX Server + NEW + + + + NIXL Agent + UCX backend • Pre-registered buffers + EXISTING + + + + Apply Weights In-Place + FP8 quantization on receiver + + + + SyncPolicyController + Sync • One-step-off • Async + NEW + + + + Resume Rollout Generation + + + + + + gRPC + + + + + gRPC + + + + + + + RDMA WRITE (per bucket) + + + + LEGEND + + New component + + Modified + + Existing + + Metadata + + Data plane (RDMA) + +
+
+ + +
+
Integration Map
+

Three Frameworks, One Abstraction

+ +

+ A framework-agnostic WeightSyncBackend decouples the transfer mechanism from each framework's orchestration and sync policy. +

+ + + + + + + + + + + + + + + + + +
DimensionNeMo RLverlPRIME-RL
Training BackendDTensor / MegatronFSDP / FSDP2 / MegatronFSDP2 (EP/CP)
Inference BackendvLLM, SGLangvLLM, SGLang, HFvLLM
Current SyncZMQ IPC / HTTP / NCCLNCCL / CheckpointEngineFilesystem + HTTP
MX Insertion Pointrefit_policy_generation()CheckpointEngine ABCOrchestrator relay
Primary BenefitRDMA replaces ZMQ/NCCLMulti-node sans filesystemEliminates disk I/O
+ +
+
+

NeMo RL

+

New branch alongside ZMQ IPC and NCCL in refit function. Bucket-streamed transfer maps to MX publish. Ray actor integration.

+
+
+

verl

+

ModelExpressCheckpointEngine implements v0.7 CheckpointEngine ABC. Targets async server mode. Engine mode stays NCCL.

+
+
+

PRIME-RL

+

Replaces filesystem relay in orchestrator. For cross-DC, MX acts as fast intra-cluster delivery under SHARDCAST.

+
+
+
+ + +
+
Delivery
+

Phased Integration Plan

+ +
+
+
Phase 1 — Weeks 1-4
+

Foundation

+
    +
  • WeightExtractor abstraction (FSDP2 + Megatron)
  • +
  • MxTrainingPublisher with NIXL + gRPC
  • +
  • Proto extensions (RL_REFIT, training_step)
  • +
  • Single-node RDMA validation
  • +
  • Benchmark vs. ZMQ IPC baseline
  • +
+
+
+
Phase 2 — Weeks 5-10
+

Framework Integrations

+
    +
  • NeMo RL: MX branch in refit function
  • +
  • verl: ModelExpressCheckpointEngine
  • +
  • PRIME-RL: Orchestrator MX relay
  • +
  • Fallback to existing mechanisms
  • +
  • E2E GRPO/PPO correctness tests
  • +
+
+
+
Phase 3 — Weeks 11-13
+

Hardening

+
    +
  • ResharderPlugin (gather-then-shard)
  • +
  • Error handling + fallback paths
  • +
  • MoE bucket completion tracking
  • +
  • FP8 quantization validation
  • +
  • Multi-node benchmarks (70B, MoE)
  • +
+
+
+ +
+
+

Llama-3.1-8B

+

2 nodes · NCCL ~3s → MX ~1s

+
+
+

Llama-3.3-70B

+

4 nodes · ~10-20s → MX ~5s

+
+
+

DeepSeek-V3 MoE

+

8+ nodes · ~30s+ → MX ~15s

+
+
+ +
    +
  • Key risk: Parallelism layout mismatch — mitigated by ResharderPlugin and config alignment guidance
  • +
  • Dependency: InfiniBand for full RDMA benefit; clean fallback to NCCL/TCP/filesystem
  • +
+
+ +
+ +
01 / 06
+ + + + + + From 66a546aa6280dce02277bb2905f78e5b41dc8012 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Sat, 4 Apr 2026 14:29:40 -0700 Subject: [PATCH 02/25] docs: add markdown version of RL integration slide deck Mirrors all 6 slides from the HTML presentation in plain markdown for easier viewing on GitHub and compatibility with Marp/Slidev. Made-with: Cursor Signed-off-by: Kavin Krishnan --- docs/slides/mx-rl-integration-slides.md | 199 ++++++++++++++++++++++++ 1 file changed, 199 insertions(+) create mode 100644 docs/slides/mx-rl-integration-slides.md diff --git a/docs/slides/mx-rl-integration-slides.md b/docs/slides/mx-rl-integration-slides.md new file mode 100644 index 00000000..db1bd783 --- /dev/null +++ b/docs/slides/mx-rl-integration-slides.md @@ -0,0 +1,199 @@ +# ModelExpress for RL Weight Updates + +> **April 2026 β€” Integration Design Overview** +> +> `NVIDIA NIXL` Β· `ModelExpress` Β· `RL Post-Training` + +Extending GPU-to-GPU RDMA transfers from **inference scaling** to the **trainingβ†’inference refit boundary** in reinforcement learning post-training. + +Target frameworks: **NeMo RL** Β· **verl** Β· **PRIME-RL** + +--- + +## Slide 1 β€” Title + +**ModelExpress for RL Weight Updates** + +Extending GPU-to-GPU RDMA transfers from **inference scaling** to the **trainingβ†’inference refit boundary** in reinforcement learning post-training. + +Target frameworks: **NeMo RL** Β· **verl** Β· **PRIME-RL** + +--- + +## Slide 2 β€” The Problem: The Weight Sync Bottleneck + +On-policy RL (GRPO, PPO, DAPO) alternates between rollout generation on inference GPUs and gradient updates on training GPUs. After every training step, updated weights must reach inference before the next rollout β€” this **refit phase** stalls both sides. + +### Wall-clock time breakdown (illustrative) + +``` +| Rollout (40%) | Rew | Train (20%) | β–ˆβ–ˆ REFIT (30%) β–ˆβ–ˆ | + β–² BOTTLENECK β–² +``` + +> Up to 30–40% of wall-clock for 70B+ models + +### Current refit latency (70B-class model, multi-node) + +| Method | Latency | +|--------|---------| +| Filesystem (PRIME-RL) | ~20s+ | +| NCCL Broadcast (NeMo RL) | ~10s | +| ZMQ IPC (NeMo RL, co-located) | ~3-5s | +| **MX RDMA P2P (target)** | **~5s** | + +--- + +## Slide 3 β€” The Solution: ModelExpress for Trainingβ†’Inference Refit + +Extend MX from inference-to-inference P2P to the trainingβ†’inference boundary. Training workers register updated weights with NIXL, publish metadata to the MX Server, and RDMA-WRITE directly into inference GPU memory β€” bypassing CPU, disk, and collective overheads. + +### High-level data flow + +``` +Training Workers MX Server Inference Workers +(FSDP2 / Megatron) (gRPC + Redis/CRD) (vLLM / SGLang) + + WeightExtractor ──gRPC──► Metadata Coord ◄──gRPC── MxRefitReceiver + MxTrainingPublisher Version Tracking NIXL Agent + NIXL Agent + β”‚ β–² + └══════════════ RDMA WRITE (GPUβ†’GPU) β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β”˜ + bypasses CPU & disk +``` + +> *See: [diagram-architecture.svg](diagram-architecture.svg)* + +### Performance + +| Metric | Value | +|--------|-------| +| Llama-3.3-70B (140 GB) | **~5s** | +| DeepSeek-V3 MoE (681 GB) | **~15s** | +| CPU staging required | **0** | +| Transport fallback | **Auto**: IPC β†’ RDMA β†’ TCP | + +--- + +## Slide 4 β€” Architecture: Component Deep-Dive + +> *See: [diagram-component-stack.svg](diagram-component-stack.svg)* + +### Training Workers + +| Component | Status | Description | +|-----------|--------|-------------| +| RL Framework | Existing | NeMo RL / verl / PRIME-RL | +| Training Backend | Existing | FSDP2 / Megatron-LM | +| **WeightExtractor** | **NEW** | Gather params per bucket (FSDP2 / Megatron) | +| **MxTrainingPublisher** | **NEW** | Register tensors with NIXL, publish metadata (gRPC), version tag | +| **ResharderPlugin** | **NEW** | Gather-then-shard / Direct-match / Auto | +| NIXL Agent | Existing | UCX backend, per-GPU, RDMA WRITE | + +### ModelExpress Server (Rust) + +| Component | Status | Description | +|-----------|--------|-------------| +| gRPC P2P Service | Existing | PublishMetadata, ListSources, GetMetadata, UpdateStatus | +| Redis / K8s CRD Backend | Existing | Metadata persistence and HA | +| **Refit Version Tracking** | **NEW** | training_step in SourceIdentity; version-filtered ListSources | +| **Bucket Coordination** | **NEW** | RefitCoordination message for bucket-level progress | +| Heartbeat / Reaper | Existing | Stale source detection and GC | +| p2p.proto | **MODIFIED** | New enums, fields, messages for RL refit | + +### Inference Workers + +| Component | Status | Description | +|-----------|--------|-------------| +| Inference Engine | Existing | vLLM / SGLang | +| **MxRefitReceiver** | **NEW** | Poll for new weight versions, coordinate with MX Server | +| NIXL Agent | Existing | UCX backend, pre-registered receive buffers | +| Apply Weights In-Place | Existing | FP8 quantization on receiver side | +| **SyncPolicyController** | **NEW** | Sync / One-step-off / Fully async modes | +| Resume Rollout | Existing | Continue generation after refit | + +### Data paths + +- **Metadata (gRPC)**: Training β†’ MX Server β†’ Inference (dashed, control plane) +- **Data plane (RDMA WRITE)**: Training NIXL Agent β†’ Inference NIXL Agent (per bucket) + +--- + +## Slide 5 β€” Integration Map: Three Frameworks, One Abstraction + +A framework-agnostic **WeightSyncBackend** decouples the transfer mechanism from each framework's orchestration and sync policy. + +> *See: [diagram-framework-comparison.svg](diagram-framework-comparison.svg)* + +### Comparison + +| Dimension | NeMo RL | verl | PRIME-RL | +|-----------|---------|------|----------| +| Training Backend | DTensor / Megatron | FSDP / FSDP2 / Megatron | FSDP2 (EP/CP) | +| Inference Backend | vLLM, SGLang | vLLM, SGLang, HF | vLLM | +| Current Sync | ZMQ IPC / HTTP / NCCL | NCCL / CheckpointEngine | Filesystem + HTTP | +| **MX Insertion Point** | `refit_policy_generation()` | `CheckpointEngine` ABC | Orchestrator `relay_weights()` | +| Primary Benefit | RDMA replaces ZMQ/NCCL | Multi-node sans filesystem | Eliminates disk I/O | + +### Per-framework summary + +**NeMo RL** β€” New branch alongside ZMQ IPC and NCCL in refit function. Bucket-streamed transfer maps to MX publish. Ray actor integration. + +**verl** β€” `ModelExpressCheckpointEngine` implements v0.7 `CheckpointEngine` ABC. Targets async server mode. Engine mode stays NCCL (already optimal co-located). + +**PRIME-RL** β€” Replaces filesystem relay in orchestrator. For cross-DC (Intellect-2), MX acts as fast intra-cluster delivery under SHARDCAST. + +--- + +## Slide 6 β€” Delivery: Phased Integration Plan + +### Phase 1 β€” Weeks 1-4: Foundation + +- WeightExtractor abstraction (FSDP2 + Megatron) +- MxTrainingPublisher with NIXL + gRPC +- Proto extensions (`RL_REFIT`, `training_step`) +- Single-node RDMA validation +- Benchmark vs. ZMQ IPC baseline + +### Phase 2 β€” Weeks 5-10: Framework Integrations + +- NeMo RL: MX branch in refit function +- verl: `ModelExpressCheckpointEngine` +- PRIME-RL: Orchestrator MX relay +- Fallback to existing mechanisms +- E2E GRPO/PPO correctness tests + +### Phase 3 β€” Weeks 11-13: Hardening + +- ResharderPlugin (gather-then-shard) +- Error handling + fallback paths +- MoE bucket completion tracking +- FP8 quantization validation +- Multi-node benchmarks (70B, MoE) + +### Target Performance + +| Model | Nodes | Current | MX Target | +|-------|-------|---------|-----------| +| Llama-3.1-8B | 2 | NCCL ~3s | **MX ~1s** | +| Llama-3.3-70B | 4 | ~10-20s | **MX ~5s** | +| DeepSeek-V3 MoE | 8+ | ~30s+ | **MX ~15s** | + +### Key risks + +- **Parallelism layout mismatch** β€” mitigated by ResharderPlugin and config alignment guidance +- **InfiniBand dependency** β€” clean fallback to NCCL/TCP/filesystem when unavailable + +--- + +## Diagrams + +All diagrams are available as standalone SVG files for embedding in external slide tools: + +| File | Description | +|------|-------------| +| [diagram-rl-loop-bottleneck.svg](diagram-rl-loop-bottleneck.svg) | RL training loop with refit bottleneck highlighted | +| [diagram-architecture.svg](diagram-architecture.svg) | Three-column architecture: Training β†’ MX Server β†’ Inference | +| [diagram-component-stack.svg](diagram-component-stack.svg) | Full component stack with NEW/MODIFIED/EXISTING tags | +| [diagram-transfer-flow.svg](diagram-transfer-flow.svg) | Sequence diagram: one refit step (publish β†’ poll β†’ RDMA WRITE β†’ apply) | +| [diagram-framework-comparison.svg](diagram-framework-comparison.svg) | NeMo RL / verl / PRIME-RL comparison grid | From 9ce518ddf3eeec4d7492e5816d23be30e2e8d927 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Mon, 6 Apr 2026 19:10:47 -0700 Subject: [PATCH 03/25] feat: add MxTrainingPublisher and MxRefitReceiver for RL weight refit Training-side publisher registers updated model weights with NIXL and publishes metadata to the MX Server. Inference-side receiver discovers sources via ListSources, pulls weights via RDMA, and yields (name, tensor) pairs compatible with vLLM's load_weights(). Supports both all-at-once and layer-by-layer streaming patterns for PRIME-RL integration. Made-with: Cursor Signed-off-by: Kavin Krishnan --- .../python/modelexpress/__init__.py | 4 + .../python/modelexpress/refit_receiver.py | 298 ++++++++++++++++++ .../python/modelexpress/training_publisher.py | 272 ++++++++++++++++ 3 files changed, 574 insertions(+) create mode 100644 modelexpress_client/python/modelexpress/refit_receiver.py create mode 100644 modelexpress_client/python/modelexpress/training_publisher.py diff --git a/modelexpress_client/python/modelexpress/__init__.py b/modelexpress_client/python/modelexpress/__init__.py index f95c3f05..3bb46d99 100644 --- a/modelexpress_client/python/modelexpress/__init__.py +++ b/modelexpress_client/python/modelexpress/__init__.py @@ -71,12 +71,16 @@ def register_modelexpress_loaders(): from .gds_loader import MxGdsLoader # noqa: F401 from .gds_transfer import GdsTransferManager # noqa: F401 from .heartbeat import HeartbeatThread # noqa: F401 +from .training_publisher import MxTrainingPublisher # noqa: F401 +from .refit_receiver import MxRefitReceiver # noqa: F401 __all__ = [ "GdsTransferManager", "HeartbeatThread", "MxClient", "MxGdsLoader", + "MxRefitReceiver", + "MxTrainingPublisher", "configure_vllm_logging", "register_modelexpress_loaders", ] diff --git a/modelexpress_client/python/modelexpress/refit_receiver.py b/modelexpress_client/python/modelexpress/refit_receiver.py new file mode 100644 index 00000000..2d8e0ddf --- /dev/null +++ b/modelexpress_client/python/modelexpress/refit_receiver.py @@ -0,0 +1,298 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Inference-side weight receiver for RL refit via ModelExpress. + +Wraps NixlTransferManager + MxClient to discover updated weights +published by the training side, pull them via RDMA, and yield +``(name, tensor)`` pairs compatible with vLLM's ``model.load_weights()``. + +Typical usage in a vLLM worker extension:: + + receiver = MxRefitReceiver("inference-0", device_id=0, mx_server_url="mx-server:8001") + receiver.initialize(model_tensors=dict(model.named_parameters())) + + source = receiver.poll_for_source(model_name="Qwen/Qwen2.5-1.5B") + if source is not None: + for name, tensor in receiver.receive_weights(source): + ... # load into model +""" + +from __future__ import annotations + +import logging +import time +from dataclasses import dataclass +from typing import Iterator + +import torch + +from .client import MxClient +from .nixl_transfer import NixlTransferManager, is_nixl_available +from .types import TensorDescriptor +from . import p2p_pb2 + +logger = logging.getLogger("modelexpress.refit_receiver") + + +@dataclass +class SourceRef: + """Lightweight handle to a discovered weight source on the MX Server.""" + mx_source_id: str + worker_id: str + model_name: str + worker_rank: int + training_step: int + + +class MxRefitReceiver: + """Receives updated weights from a training process via ModelExpress RDMA. + + One instance per GPU rank on the inference side. Discovers training + sources via the MX Server, pulls weight tensors over NIXL RDMA, + and yields them for ``model.load_weights()``. + + Args: + agent_name: Unique NIXL agent name (e.g. ``"inference-rank-0"``). + device_id: CUDA device index for this inference rank. + mx_server_url: gRPC address of the ModelExpress server. + listen_port: Optional NIXL listen port for P2P metadata exchange. + """ + + def __init__( + self, + agent_name: str, + device_id: int, + mx_server_url: str = "localhost:8001", + listen_port: int | None = None, + ): + self._agent_name = agent_name + self._device_id = device_id + self._mx_server_url = mx_server_url + self._listen_port = listen_port + + self._nixl: NixlTransferManager | None = None + self._client: MxClient | None = None + self._initialized = False + self._current_step = -1 + + @property + def current_step(self) -> int: + """The most recently received training step.""" + return self._current_step + + def initialize(self, model_tensors: dict[str, torch.Tensor] | None = None) -> None: + """Initialize NIXL agent, MX client, and optionally register receive buffers. + + Args: + model_tensors: If provided, registers these tensors with NIXL as + receive buffers. For tensor-name-matched transfers, the source's + tensors are written directly into these buffers. If *None*, + the caller must register tensors separately. + """ + if not is_nixl_available(): + raise RuntimeError( + "NIXL is not available. Install nixl or build from source." + ) + + self._nixl = NixlTransferManager( + agent_name=self._agent_name, + device_id=self._device_id, + listen_port=self._listen_port, + ) + self._nixl.initialize() + + if model_tensors is not None: + self._nixl.register_tensors(model_tensors) + logger.info( + f"Registered {len(model_tensors)} receive buffers with NIXL" + ) + + self._client = MxClient(server_url=self._mx_server_url) + self._initialized = True + logger.info( + f"MxRefitReceiver initialized: agent={self._agent_name}, " + f"device={self._device_id}" + ) + + def poll_for_source( + self, + model_name: str, + min_step: int | None = None, + status_filter: int = p2p_pb2.SOURCE_STATUS_READY, + timeout_seconds: float = 0, + ) -> SourceRef | None: + """Check the MX Server for a training source with updated weights. + + Args: + model_name: Model name to filter on (must match publisher's identity). + min_step: If set, only return sources with ``training_step >= min_step``. + Defaults to ``current_step + 1`` to only find newer versions. + timeout_seconds: If > 0, poll repeatedly until a source is found + or timeout is reached. If 0, check once and return immediately. + + Returns: + A :class:`SourceRef` if a matching source was found, else *None*. + """ + if not self._initialized: + raise RuntimeError("Call initialize() before poll_for_source()") + + if min_step is None: + min_step = self._current_step + 1 + + identity = p2p_pb2.SourceIdentity( + model_name=model_name, + mx_source_type=p2p_pb2.MX_SOURCE_TYPE_WEIGHTS, + ) + + deadline = time.perf_counter() + timeout_seconds + + while True: + try: + response = self._client.list_sources( + identity=identity, + status_filter=status_filter, + ) + except Exception as e: + logger.warning(f"list_sources failed: {e}") + if time.perf_counter() >= deadline: + return None + time.sleep(0.5) + continue + + for instance in response.instances: + step_str = "" + try: + meta_resp = self._client.get_metadata( + mx_source_id=instance.mx_source_id, + worker_id=instance.worker_id, + ) + if meta_resp.found and meta_resp.worker: + worker = meta_resp.worker + if hasattr(worker, "tensors") and len(worker.tensors) > 0: + step_str = "" + for t in worker.tensors: + if t.name == "__training_step__": + step_str = t.dtype + break + except Exception: + pass + + source_step = int(step_str) if step_str.isdigit() else 0 + + if source_step >= min_step: + return SourceRef( + mx_source_id=instance.mx_source_id, + worker_id=instance.worker_id, + model_name=instance.model_name, + worker_rank=instance.worker_rank, + training_step=source_step, + ) + + if time.perf_counter() >= deadline: + return None + time.sleep(0.5) + + def receive_weights( + self, + source: SourceRef, + timeout_seconds: float = 300.0, + ) -> Iterator[tuple[str, torch.Tensor]]: + """Receive weights from a discovered source via NIXL RDMA. + + Fetches the source's NIXL metadata and tensor descriptors from the + MX Server, establishes an RDMA connection, and transfers weight + tensors into locally registered buffers. + + Args: + source: A :class:`SourceRef` obtained from :meth:`poll_for_source`. + timeout_seconds: Maximum time to wait for the RDMA transfer. + + Yields: + ``(name, tensor)`` pairs suitable for ``model.load_weights()``. + """ + if not self._initialized: + raise RuntimeError("Call initialize() before receive_weights()") + + meta_resp = self._client.get_metadata( + mx_source_id=source.mx_source_id, + worker_id=source.worker_id, + ) + if not meta_resp.found: + raise RuntimeError( + f"Source {source.mx_source_id}/{source.worker_id} not found on MX Server" + ) + + worker = meta_resp.worker + source_tensors = [ + TensorDescriptor( + name=t.name, + addr=t.addr, + size=t.size, + device_id=t.device_id, + dtype=t.dtype, + ) + for t in worker.tensors + ] + + transferred, skipped, elapsed = self._nixl.receive_from_source( + source_metadata=worker.nixl_metadata, + source_tensors=source_tensors, + timeout_seconds=timeout_seconds, + ) + + logger.info( + f"RDMA transfer complete: {transferred} bytes, " + f"{len(source_tensors)} tensors, {elapsed:.2f}s " + f"(step={source.training_step})" + ) + + self._current_step = source.training_step + + for td in source_tensors: + if td.name in self._nixl._tensors: + yield td.name, self._nixl._tensors[td.name] + + def receive_weights_from_metadata( + self, + nixl_metadata: bytes, + source_tensors: list[TensorDescriptor], + training_step: int, + timeout_seconds: float = 300.0, + ) -> Iterator[tuple[str, torch.Tensor]]: + """Receive weights when metadata is already available (bypasses MX Server query). + + Useful when the orchestrator passes metadata directly instead of + having the worker poll the MX Server. + """ + if not self._initialized: + raise RuntimeError("Call initialize() first") + + transferred, skipped, elapsed = self._nixl.receive_from_source( + source_metadata=nixl_metadata, + source_tensors=source_tensors, + timeout_seconds=timeout_seconds, + ) + + logger.info( + f"RDMA transfer (direct metadata): {transferred} bytes, " + f"{len(source_tensors)} tensors, {elapsed:.2f}s" + ) + + self._current_step = training_step + + for td in source_tensors: + if td.name in self._nixl._tensors: + yield td.name, self._nixl._tensors[td.name] + + def shutdown(self) -> None: + """Release NIXL agent and close gRPC channel.""" + if self._nixl is not None: + self._nixl.shutdown() + self._nixl = None + if self._client is not None: + self._client.close() + self._client = None + self._initialized = False + logger.info(f"MxRefitReceiver shut down: {self._agent_name}") diff --git a/modelexpress_client/python/modelexpress/training_publisher.py b/modelexpress_client/python/modelexpress/training_publisher.py new file mode 100644 index 00000000..11160955 --- /dev/null +++ b/modelexpress_client/python/modelexpress/training_publisher.py @@ -0,0 +1,272 @@ +# SPDX-FileCopyrightText: Copyright (c) 2025-2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Training-side weight publisher for RL refit via ModelExpress. + +Wraps NixlTransferManager + MxClient to register updated model weights +on the training GPU and publish metadata to the MX Server so that +inference workers can discover and pull them via RDMA. + +Typical usage in an RL training loop:: + + publisher = MxTrainingPublisher("trainer-0", device_id=0, mx_server_url="mx-server:8001") + publisher.initialize(model_name="Qwen/Qwen2.5-1.5B") + + # After optimizer.step(): + for layer_idx, layer_sd in enumerate_layers(model): + publisher.publish_layer(layer_sd, layer_idx, step=training_step) + publisher.mark_ready() +""" + +from __future__ import annotations + +import logging +import uuid +from typing import Iterator + +import torch + +from .client import MxClient +from .nixl_transfer import NixlTransferManager, is_nixl_available +from .types import TensorDescriptor +from . import p2p_pb2 + +logger = logging.getLogger("modelexpress.training_publisher") + + +class MxTrainingPublisher: + """Publishes updated model weights from a training process to ModelExpress. + + One instance per GPU rank. On the training side, after each optimizer step, + the publisher registers weight tensors with NIXL and publishes metadata to + the MX Server. Inference workers discover the source via ``ListSources`` + and pull weights via RDMA. + + Args: + agent_name: Unique NIXL agent name (e.g. ``"trainer-rank-0"``). + device_id: CUDA device index for this training rank. + mx_server_url: gRPC address of the ModelExpress server. + listen_port: Optional NIXL listen port for P2P metadata exchange. + """ + + def __init__( + self, + agent_name: str, + device_id: int, + mx_server_url: str = "localhost:8001", + listen_port: int | None = None, + ): + self._agent_name = agent_name + self._device_id = device_id + self._mx_server_url = mx_server_url + self._listen_port = listen_port + + self._nixl: NixlTransferManager | None = None + self._client: MxClient | None = None + self._worker_id: str = str(uuid.uuid4()) + self._mx_source_id: str | None = None + self._model_name: str = "" + self._initialized = False + + @property + def mx_source_id(self) -> str | None: + return self._mx_source_id + + @property + def worker_id(self) -> str: + return self._worker_id + + def initialize( + self, + model_name: str, + tensor_parallel_size: int = 1, + pipeline_parallel_size: int = 1, + expert_parallel_size: int = 1, + dtype: str = "bfloat16", + ) -> None: + """Initialize NIXL agent and MX client. + + Must be called before any publish operations. Sets up the source + identity that inference workers will use to filter compatible sources. + """ + if not is_nixl_available(): + raise RuntimeError( + "NIXL is not available. Install nixl or build from source." + ) + + self._model_name = model_name + self._identity_kwargs = dict( + model_name=model_name, + mx_source_type=p2p_pb2.MX_SOURCE_TYPE_WEIGHTS, + backend_framework=p2p_pb2.BACKEND_FRAMEWORK_UNKNOWN, + tensor_parallel_size=tensor_parallel_size, + pipeline_parallel_size=pipeline_parallel_size, + expert_parallel_size=expert_parallel_size, + dtype=dtype, + ) + + self._nixl = NixlTransferManager( + agent_name=self._agent_name, + device_id=self._device_id, + listen_port=self._listen_port, + ) + self._nixl.initialize() + + self._client = MxClient(server_url=self._mx_server_url) + self._initialized = True + logger.info( + f"MxTrainingPublisher initialized: agent={self._agent_name}, " + f"device={self._device_id}, model={model_name}" + ) + + def _build_identity(self, step: int) -> p2p_pb2.SourceIdentity: + """Build a SourceIdentity proto with the current training step.""" + return p2p_pb2.SourceIdentity( + extra_parameters={ + "training_step": str(step), + "training_framework": "prime_rl", + }, + **self._identity_kwargs, + ) + + def _build_tensor_protos( + self, descriptors: list[TensorDescriptor] + ) -> list[p2p_pb2.TensorDescriptor]: + return [ + p2p_pb2.TensorDescriptor( + name=d.name, + addr=d.addr, + size=d.size, + device_id=d.device_id, + dtype=d.dtype, + ) + for d in descriptors + ] + + def publish_weights( + self, + named_tensors: dict[str, torch.Tensor], + step: int, + worker_rank: int = 0, + ) -> str: + """Register tensors with NIXL and publish metadata to MX Server. + + This is the all-at-once variant. For layer-by-layer streaming, + use :meth:`publish_layer` instead. + + Args: + named_tensors: Mapping of parameter name to GPU tensor. + step: Current training step (used for version tracking). + worker_rank: GPU rank of this worker within the training group. + + Returns: + The ``mx_source_id`` (16-char hex) assigned by the server. + """ + if not self._initialized: + raise RuntimeError("Call initialize() before publish_weights()") + + self._nixl.register_tensors(named_tensors) + metadata = self._nixl.nixl_metadata + descriptors = self._nixl.tensor_descriptors + + identity = self._build_identity(step) + worker_meta = p2p_pb2.WorkerMetadata( + worker_rank=worker_rank, + nixl_metadata=metadata, + tensors=self._build_tensor_protos(descriptors), + status=p2p_pb2.SOURCE_STATUS_INITIALIZING, + agent_name=self._agent_name, + ) + + self._mx_source_id = self._client.publish_metadata( + identity=identity, + worker=worker_meta, + worker_id=self._worker_id, + ) + logger.info( + f"Published {len(named_tensors)} tensors for step {step} " + f"(mx_source_id={self._mx_source_id})" + ) + return self._mx_source_id + + def publish_layer( + self, + layer_state_dict: dict[str, torch.Tensor], + layer_idx: int, + step: int, + worker_rank: int = 0, + ) -> str: + """Publish a single layer's weights to MX Server. + + Designed for PRIME-RL's layer-by-layer streaming pattern where + ``filter_state_dict_by_layers()`` yields one layer at a time. + + Layer tensors are registered with NIXL (overwriting previous + registration), and metadata is published to the MX Server. The + inference side accumulates all layers before loading. + + Args: + layer_state_dict: Parameter name -> tensor for this layer. + layer_idx: Layer index (-1 for non-layer weights like embeddings). + step: Current training step. + worker_rank: GPU rank of this worker. + + Returns: + The ``mx_source_id`` assigned by the server. + """ + if not self._initialized: + raise RuntimeError("Call initialize() before publish_layer()") + + self._nixl.register_tensors(layer_state_dict) + metadata = self._nixl.nixl_metadata + descriptors = self._nixl.tensor_descriptors + + identity = self._build_identity(step) + identity.extra_parameters["layer_idx"] = str(layer_idx) + + worker_meta = p2p_pb2.WorkerMetadata( + worker_rank=worker_rank, + nixl_metadata=metadata, + tensors=self._build_tensor_protos(descriptors), + status=p2p_pb2.SOURCE_STATUS_INITIALIZING, + agent_name=self._agent_name, + ) + + self._mx_source_id = self._client.publish_metadata( + identity=identity, + worker=worker_meta, + worker_id=self._worker_id, + ) + logger.debug( + f"Published layer {layer_idx} ({len(layer_state_dict)} tensors) " + f"for step {step}" + ) + return self._mx_source_id + + def mark_ready(self, worker_rank: int = 0) -> bool: + """Signal that all layers/weights have been published and are ready. + + Inference workers filter on ``SOURCE_STATUS_READY`` when polling, + so this must be called after all publish calls for a given step. + """ + if self._mx_source_id is None: + raise RuntimeError("No weights published yet; call publish_weights() first") + + return self._client.update_status( + mx_source_id=self._mx_source_id, + worker_id=self._worker_id, + worker_rank=worker_rank, + status=p2p_pb2.SOURCE_STATUS_READY, + ) + + def shutdown(self) -> None: + """Release NIXL agent and close gRPC channel.""" + if self._nixl is not None: + self._nixl.shutdown() + self._nixl = None + if self._client is not None: + self._client.close() + self._client = None + self._initialized = False + logger.info(f"MxTrainingPublisher shut down: {self._agent_name}") From c34003ac8daec74c2887802a66644886e4c5dd02 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Thu, 9 Apr 2026 22:42:49 -0700 Subject: [PATCH 04/25] fix: list all sources and filter client-side by model_name (identity hash mismatch) Made-with: Cursor Signed-off-by: Kavin Krishnan --- .../python/modelexpress/refit_receiver.py | 43 +++++-------------- 1 file changed, 10 insertions(+), 33 deletions(-) diff --git a/modelexpress_client/python/modelexpress/refit_receiver.py b/modelexpress_client/python/modelexpress/refit_receiver.py index 2d8e0ddf..1bcbab8a 100644 --- a/modelexpress_client/python/modelexpress/refit_receiver.py +++ b/modelexpress_client/python/modelexpress/refit_receiver.py @@ -141,17 +141,11 @@ def poll_for_source( if min_step is None: min_step = self._current_step + 1 - identity = p2p_pb2.SourceIdentity( - model_name=model_name, - mx_source_type=p2p_pb2.MX_SOURCE_TYPE_WEIGHTS, - ) - deadline = time.perf_counter() + timeout_seconds while True: try: response = self._client.list_sources( - identity=identity, status_filter=status_filter, ) except Exception as e: @@ -162,33 +156,16 @@ def poll_for_source( continue for instance in response.instances: - step_str = "" - try: - meta_resp = self._client.get_metadata( - mx_source_id=instance.mx_source_id, - worker_id=instance.worker_id, - ) - if meta_resp.found and meta_resp.worker: - worker = meta_resp.worker - if hasattr(worker, "tensors") and len(worker.tensors) > 0: - step_str = "" - for t in worker.tensors: - if t.name == "__training_step__": - step_str = t.dtype - break - except Exception: - pass - - source_step = int(step_str) if step_str.isdigit() else 0 - - if source_step >= min_step: - return SourceRef( - mx_source_id=instance.mx_source_id, - worker_id=instance.worker_id, - model_name=instance.model_name, - worker_rank=instance.worker_rank, - training_step=source_step, - ) + if instance.model_name != model_name: + continue + + return SourceRef( + mx_source_id=instance.mx_source_id, + worker_id=instance.worker_id, + model_name=instance.model_name, + worker_rank=instance.worker_rank, + training_step=0, + ) if time.perf_counter() >= deadline: return None From 7e837c3b04c7fa0a6dbdd07fc20569a2470134f1 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Fri, 10 Apr 2026 12:26:43 -0700 Subject: [PATCH 05/25] fix: register NIXL tensors only once per publisher lifetime Tensor memory addresses don't change between optimizer steps, only values do. Calling register_memory every step accumulated descriptors, inflating the metadata blob from ~27 KB to 800+ KB and causing NIXL_ERR_NOT_ALLOWED on add_remote_agent. Made-with: Cursor Signed-off-by: Kavin Krishnan --- docs/165_review.md | 175 +++ docs/170_feedback.md | 148 +++ docs/feedback.md | 1160 +++++++++++++++++ docs/feedback_pr19920.md | 863 ++++++++++++ .../python/modelexpress/training_publisher.py | 13 +- 5 files changed, 2358 insertions(+), 1 deletion(-) create mode 100644 docs/165_review.md create mode 100644 docs/170_feedback.md create mode 100644 docs/feedback.md create mode 100644 docs/feedback_pr19920.md diff --git a/docs/165_review.md b/docs/165_review.md new file mode 100644 index 00000000..bca75e68 --- /dev/null +++ b/docs/165_review.md @@ -0,0 +1,175 @@ +# PR 165 Review: Metadata Resiliency Phase 1 + +Reviewer: KavinKrishnan +PR: https://github.com/ai-dynamo/modelexpress/pull/165 +Author: zhengluo-nv + +## Overall Assessment + +Good simplification. Merging ready state into WorkerRecord and eliminating the +memory/layered backends reduces code paths and configuration permutations +significantly. The UpdateStatus RPC is cleaner than the old +PublishReady/GetReady pair. Tests are solid. + +Main concerns: (1) the stability_verified removal breaks our TRT-LLM +DeepGEMM warmup workflow, (2) the retry-on-RDMA-failure path in +vllm_loader.py does not check status before re-using stale workers, and +(3) a few edge cases in the K8s backend can cause silent data loss. + +## Comments to Leave on PR + +### 1. BLOCKING - stability_verified removal breaks DeepGEMM warmup gating + +File: modelexpress_common/proto/p2p.proto, lines 62-67 (new WorkerMetadata fields) +Also: modelexpress_server/src/k8s_types.rs, lines 66-80 (new WorkerStatus struct) + +The old stability_verified field was used to gate P2P transfers until after +DeepGEMM warmup completes on the source. For DeepSeek V3 / Kimi K2.5, this +warmup takes 30-60 seconds and writes to GPU memory. Transferring weights +before it finishes produces corrupted inference. + +The new SourceStatus enum only has Initializing, Ready, Stale. There is +no state between "metadata published" and "fully warmed up and safe to transfer." + +Suggestion: Add a SOURCE_STATUS_PENDING_VERIFICATION = 4 state (as Zheng +proposed in the PR comments), or split Ready into METADATA_READY and +SERVING_READY. The source should transition: +Initializing -> PendingVerification -> Ready. Targets should only transfer +from workers in Ready status. This makes stability_verified expressible +via the status enum without needing a separate boolean. + +### 2. IMPORTANT - Target retry loop does not filter by worker status + +File: modelexpress_client/python/modelexpress/vllm_loader.py, lines 476-490 +(retry metadata refresh inside the transfer attempt loop) + +When an RDMA transfer fails and the target re-fetches metadata, it matches +workers only by worker_rank and len(w.tensors) > 0: + + response = self._mx_client.get_metadata(model_name) + for w in response.workers: + if w.worker_rank == device_id and len(w.tensors) > 0: + source_worker = w + +This does not check w.status == SOURCE_STATUS_READY. If the source restarted +and is in Initializing or Stale state, the target will attempt RDMA against +potentially invalid GPU addresses. + +The initial detection at _detect_source_worker (line ~353) correctly does: + + ready = p2p_pb2.SOURCE_STATUS_READY + for w in metadata_resp.workers: + if w.worker_rank == device_id and w.status == ready and len(w.tensors) > 0: + +So this is just the retry path missing the identical check. + +### 3. IMPORTANT - update_status call not wrapped in error handling + +File: modelexpress_client/python/modelexpress/vllm_loader.py, lines 212-219 + +After successfully publishing metadata, the source calls update_status but +does not check the return value: + + if success: + logger.info(f"[Worker {device_id}] Published metadata to MX server") + mx_client.update_status( + model_name=model_name, + worker_id=device_id, + status=p2p_pb2.SOURCE_STATUS_READY, + ) + +If this gRPC call fails (network blip, server restart), update_status +returns False but execution continues. The source thinks it published +READY, but targets polling GetMetadata will never see Ready status for +this worker -- they will see Initializing (or whatever status was set +during publish_metadata) and skip it. + +Suggestion: Check the return value and raise on failure: + + if not mx_client.update_status(...): + raise RuntimeError( + f"[Worker {device_id}] Failed to update status to READY" + ) + +### 4. NIT - K8s update_status silently returns Ok when worker not found + +File: modelexpress_server/src/metadata_backend/kubernetes.rs (update_status fn) + +When a worker ID does not exist in the CR's worker list, the K8s backend +logs at debug level and returns Ok(()): + + } else { + debug!( + "update_status: worker {} not found in CR '{}', skipping", + worker_id, cr_name + ); + return Ok(()); + } + +The Redis backend returns Err for the same case (Lua script returns 0, +check_patched converts to error). This inconsistency means callers cannot +distinguish "status updated" from "worker not found" on the K8s backend. + +Suggestion: Return Err to match Redis, or if the intent is to be lenient +(worker calls update_status before publish_metadata arrives), document +that and make the Redis backend match by returning Ok when patched == 0. + +### 5. NIT - status_proto_from_name rejects Unknown -- breaks CRD backward compat + +File: modelexpress_server/src/k8s_types.rs, lines 83-92 + +status_proto_from_name returns None for "Unknown", and the K8s backend +get_metadata converts None into a hard error. But the CRD schema defaults +status to "Unknown", so pre-existing CRs will fail to read. + +Suggestion: Map "Unknown" to Some(0) since proto defines SOURCE_STATUS_UNKNOWN = 0. + +### 6. MINOR - CRD lost all useful printer columns except Model and Age + +File: examples/p2p_transfer_k8s/deploy/persistence/crd-modelmetadata.yaml, lines 110-115 + +kubectl get modelmetadata now only shows Model and Age. Add back Workers count +and a Status summary column. + +### 7. MINOR - metadata.md just has WIP banner but keeps 600 lines of stale content + +File: docs/metadata.md, lines 1-3 + +Either update to match new architecture or delete and point to ARCHITECTURE.md. +Stale doc with one-line disclaimer is worse than no doc. + +### 8. MINOR - Dead condition types remain in CRD schema + +File: examples/p2p_transfer_k8s/deploy/persistence/crd-modelmetadata.yaml, lines 81-82 + +AllWorkersPublished and Ready conditions are defined in schema but nothing in +code populates them anymore. Remove or re-implement. + +### 9. NIT - main.rs errors do not identify which backend failed + +File: modelexpress_server/src/main.rs, lines 104-113 + +Error messages say "P2P metadata backend" without naming which backend or +connection target. Include MX_METADATA_BACKEND value in the message. + +### 10. QUESTION - Local dev story without in-memory backend + +File: layered.rs (deleted), memory.rs (deleted) + +MX_METADATA_BACKEND is now required. Local dev needs Redis or K8s. +Document the recommended local setup (Docker Compose with Redis sidecar?). + +## Summary Table + +| # | Severity | File | Lines | Topic | +|---|----------|------|-------|-------| +| 1 | BLOCKING | p2p.proto, k8s_types.rs | 62-67, 66-80 | stability_verified removal | +| 2 | IMPORTANT | vllm_loader.py | 476-490 | Retry loop missing status check | +| 3 | IMPORTANT | vllm_loader.py | 212-219 | update_status failure ignored | +| 4 | NIT | kubernetes.rs | 500-510 | Inconsistent Ok vs Err | +| 5 | NIT | k8s_types.rs | 83-92 | Unknown breaks backward compat | +| 6 | MINOR | crd-modelmetadata.yaml | 110-115 | Printer columns removed | +| 7 | MINOR | metadata.md | 1-3 | Stale doc | +| 8 | MINOR | crd-modelmetadata.yaml | 81-82 | Dead conditions | +| 9 | NIT | main.rs | 104-113 | Non-descriptive errors | +| 10 | QUESTION | layered.rs, memory.rs | deleted | Local dev story | diff --git a/docs/170_feedback.md b/docs/170_feedback.md new file mode 100644 index 00000000..90f2c121 --- /dev/null +++ b/docs/170_feedback.md @@ -0,0 +1,148 @@ +# PR 170 Review: Multi-Source P2P Metadata with Per-Worker APIs + +Reviewer: KavinKrishnan +PR: https://github.com/ai-dynamo/modelexpress/pull/170 +Author: zhengluo-nv + +## Overall Assessment + +Strong architectural redesign. The move from model-name keys to content-addressed +SourceIdentity (mx_source_id), per-worker publish/get, and ListSources RPC +correctly supports multiple concurrent source replicas. The two-step +ListSourcesβ†’GetMetadata flow with worker_rank filtering eliminates fan-out +RPCs. SourceTransferError for selective STALE marking is the right approach. +K8s update_status now returns Err on missing worker (CodeRabbit fix applied). + +Main concerns: (1) update_status failure in _publish_metadata_and_ready is +silently ignored, (2) TensorDescriptor lacks shape field needed for TRT-LLM +tensor reconstruction (main has it), (3) no PENDING_VERIFICATION state for +DeepGEMM warmup gating, and (4) a few doc/CRD cleanups. + +## Comments to Leave on PR + +### 1. IMPORTANT - update_status failure silently ignored in _publish_metadata_and_ready + +File: modelexpress_client/python/modelexpress/vllm_loader.py, lines 265-275 + +After successfully publishing metadata, the source calls update_status with +SOURCE_STATUS_READY. If this gRPC call fails (network blip, server restart), +the code only logs and continues: + +```python +success = mx_client.update_status( + mx_source_id=mx_source_id, + worker_id=worker_id, + worker_rank=global_rank, + status=p2p_pb2.SOURCE_STATUS_READY, +) +if not success: + logger.error( + f"[Worker {global_rank}] UpdateStatus to READY failed for " + f"model '{identity.model_name}' (mx_source_id={mx_source_id})" + ) +``` + +The source thinks it is ready, but targets never see Ready status and will +never discover this worker. Same issue as PR 165 #3. + +Suggestion: Check the return value and raise on failure so the source retries +or fails loudly instead of advertising readiness that targets cannot use. + +### 2. IMPORTANT - TensorDescriptor missing shape field + +File: modelexpress_common/proto/p2p.proto, lines 91-104 (TensorDescriptor message) + +PR 170's TensorDescriptor has name, addr, size, device_id, dtype but no shape. +Main branch (and PR 169) added `repeated int64 shape = 6` for proper tensor +reconstruction on the target. TRT-LLM and some vLLM models need shape to +correctly rebuild tensors after RDMA receive. + +Suggestion: Add `repeated int64 shape = 6` to TensorDescriptor and regenerate +stubs. Ensure vllm_loader and trtllm_loader pass shape when building +TensorDescriptor protos. + +### 3. BLOCKING (for TRT-LLM) - No PENDING_VERIFICATION state for DeepGEMM warmup + +File: modelexpress_common/proto/p2p.proto, lines 112-117 (SourceStatus enum) +Also: modelexpress_server/src/k8s_types.rs, lines 89-98 (status_name_from_proto) + +The SourceStatus enum has Unknown, Initializing, Ready, Stale. There is no +state between "metadata published" and "fully warmed up and safe to transfer." +For TRT-LLM DeepGEMM warmup (DeepSeek V3, Kimi K2.5), warmup takes 30-60 seconds +and writes to GPU memory. Transferring before it finishes produces corrupted +inference. + +Commit c75a58e had PENDING_VERIFICATION but a6cbdf5 reverted it to Unknown. +Suggestion: Re-add SOURCE_STATUS_PENDING_VERIFICATION = 4 (or use value that +does not shift Ready/Stale). Source transitions: Initializing -> +PendingVerification -> Ready. Targets only transfer from Ready. + +### 4. NIT - validate_identity only checks model_name + +File: modelexpress_server/src/source_identity.rs, lines 25-30 + +validate_identity only checks identity.model_name. SourceIdentity includes +backend_framework and mx_source_type. backend_framework=0 (UNKNOWN) may +indicate uninitialized or malformed identity. + +Suggestion (optional): Add validation for backend_framework when +BACKEND_FRAMEWORK_UNKNOWN should never be published. Return Err with clear +message so malformed identities are rejected early. + +### 5. MINOR - CRD printer columns reduced to Model and Age + +File: examples/p2p_transfer_k8s/deploy/persistence/crd-modelmetadata.yaml, lines 119-126 + +kubectl get modelmetadata now only shows Model and Age. Add back Workers count +and optionally a Status summary column for easier debugging. + +### 6. MINOR - Dead condition types in CRD schema + +File: examples/p2p_transfer_k8s/deploy/persistence/crd-modelmetadata.yaml, lines 84-86 + +AllWorkersPublished and Ready conditions are defined in the schema enum but +nothing in code populates them. Remove or re-implement. + +### 7. NIT - Docstring coverage below threshold + +Pre-merge check reports docstring coverage 62.88% (required 80%). Add +docstrings for functions missing them to satisfy the threshold. + +### 8. QUESTION - Stale source detection latency (~35s per dead source) + +PR description notes: CRDs from dead pods remain "Ready" until a new target +tries them and gets NIXL_ERR_REMOTE_DISCONNECT; UCX connection timeout is +~35s per stale source. Is there a plan for heartbeat/TTL to mark stale workers +automatically? Document as known limitation or track as follow-up. + +### 9. NIT - _collect_cuda_tensors vs _iter_module_tensors + +File: modelexpress_client/python/modelexpress/vllm_loader.py + +PR 170 uses _collect_cuda_tensors (named_parameters only) instead of the +main-branch _iter_module_tensors which also finds buffers and tensor +attributes (e.g. FP8 scale_inv). For FP8 models, scale tensors may be +missed. Verify this is intentional or restore the more thorough traversal. + +### 10. MINOR - main.rs errors do not identify which backend failed + +File: modelexpress_server/src/main.rs (if present in PR 170) + +Error messages that say "P2P metadata backend" without naming which backend +or connection target make debugging harder. Include MX_METADATA_BACKEND value +(or equivalent) in the message. + +## Summary Table + +| # | Severity | File | Lines | Topic | +|---|----------|------|-------|-------| +| 1 | IMPORTANT | vllm_loader.py | 265-275 | update_status failure ignored | +| 2 | IMPORTANT | p2p.proto | 91-104 | TensorDescriptor missing shape | +| 3 | BLOCKING (TRT-LLM) | p2p.proto, k8s_types.rs | 112-117, 89-98 | No PendingVerification for warmup | +| 4 | NIT | source_identity.rs | 25-30 | validate_identity scope | +| 5 | MINOR | crd-modelmetadata.yaml | 119-126 | Printer columns | +| 6 | MINOR | crd-modelmetadata.yaml | 84-86 | Dead conditions | +| 7 | NIT | (various) | β€” | Docstring coverage | +| 8 | QUESTION | β€” | β€” | Stale detection / heartbeat | +| 9 | NIT | vllm_loader.py | _collect_cuda_tensors | FP8 scale tensors | +| 10 | MINOR | main.rs | β€” | Non-descriptive backend errors | diff --git a/docs/feedback.md b/docs/feedback.md new file mode 100644 index 00000000..ac455abf --- /dev/null +++ b/docs/feedback.md @@ -0,0 +1,1160 @@ +# PR 157: Add TransferEngine Backend to P2P Metadata - Design Review & Feedback + +## Executive Summary + +This document provides a design overview and feedback for PR 157, which adds TransferEngine backend support to ModelExpress's P2P metadata system. The review is informed by: +- Current ModelExpress P2P metadata architecture (NIXL-based) +- SGLang's R-Fork implementation using TransferEngine +- Best practices for multi-backend transfer systems + +## Current Architecture Overview + +### Existing P2P Metadata System + +ModelExpress currently supports P2P weight transfers using **NIXL** (NVIDIA Inter-Node eXchange Library) for RDMA-based GPU-to-GPU transfers: + +1. **Metadata Structure**: + - `WorkerMetadata` contains `nixl_metadata` (byte blob) + tensor descriptors + - Metadata is published via gRPC to ModelExpress server + - Server stores metadata in Redis/Kubernetes/In-memory backends + +2. **Transfer Flow**: + - Source: Loads model β†’ Registers tensors with NIXL β†’ Publishes metadata β†’ Signals ready + - Target: Queries metadata β†’ Adds remote NIXL agents β†’ Executes RDMA transfers + +3. **Backend Abstraction**: + - Server-side: `MetadataBackend` trait (Memory/Redis/Kubernetes) + - Client-side: `NixlTransferManager` for NIXL operations + +## Proposed Design: TransferEngine Backend Support + +### Design Goals (Inferred from SGLang R-Fork) + +Based on [SGLang's R-Fork documentation](https://raw.githubusercontent.com/sgl-project/sglang/main/docs/advanced_features/rfork.md), TransferEngine support should: + +1. **Enable zero-copy weight loading** from running instances +2. **Support multiple backends**: NCCL, TransferEngine (and potentially NIXL) +3. **Backend selection** based on availability and configuration +4. **Metadata routing** to appropriate backend based on backend type + +### Expected Changes + +PR 157 likely introduces: + +1. **Protocol Buffer Updates** (`p2p.proto`): + - Add `backend_type` field to `WorkerMetadata` (enum: NIXL, TRANSFER_ENGINE, NCCL) + - Add TransferEngine-specific metadata fields (connection info, ports, etc.) + - Maintain backward compatibility with existing NIXL-only deployments + +2. **Server-Side Changes**: + - Extend `WorkerRecord` to store backend type + - Update metadata serialization/deserialization + - Ensure backend-agnostic storage (metadata backend should not care about transfer backend) + +3. **Client-Side Changes**: + - Add `TransferEngineTransferManager` (parallel to `NixlTransferManager`) + - Backend selection logic (NIXL vs TransferEngine) + - TransferEngine-specific connection establishment + +## Design Feedback & Recommendations + +### 1. Protocol Buffer Design + +#### βœ… **Recommendation: Use OneOf for Backend-Specific Metadata** + +**Current Approach (Inferred)**: +```protobuf +message WorkerMetadata { + uint32 worker_rank = 1; + bytes nixl_metadata = 2; // Only NIXL + repeated TensorDescriptor tensors = 3; +} +``` + +**Recommended Approach**: +```protobuf +message WorkerMetadata { + uint32 worker_rank = 1; + + // Backend type determines which metadata field is populated + BackendType backend_type = 2; + + // Backend-specific metadata (one of these is populated) + oneof backend_metadata { + NixlBackendMetadata nixl_metadata = 3; + TransferEngineBackendMetadata transfer_engine_metadata = 4; + NcclBackendMetadata nccl_metadata = 5; // Future-proofing + } + + repeated TensorDescriptor tensors = 6; +} + +enum BackendType { + BACKEND_TYPE_UNSPECIFIED = 0; + BACKEND_TYPE_NIXL = 1; + BACKEND_TYPE_TRANSFER_ENGINE = 2; + BACKEND_TYPE_NCCL = 3; +} + +message NixlBackendMetadata { + bytes nixl_agent_metadata = 1; // Serialized NIXL agent blob +} + +message TransferEngineBackendMetadata { + // Connection information for TransferEngine + string seed_instance_ip = 1; + uint32 seed_instance_service_port = 2; + repeated uint32 send_weights_group_ports = 3; // For NCCL backend + // Additional TransferEngine-specific fields as needed +} +``` + +**Rationale**: +- **Type Safety**: Clear separation of backend-specific metadata +- **Extensibility**: Easy to add new backends (NCCL, custom) +- **Backward Compatibility**: Can deprecate old `nixl_metadata` field gradually +- **Validation**: Server can validate that backend_type matches populated metadata + +#### ⚠️ **Concern: Backward Compatibility** + +**Issue**: Existing deployments use `bytes nixl_metadata`. How does PR 157 handle migration? + +**Recommendations**: +1. **Deprecation Strategy**: Keep `nixl_metadata` field but mark as deprecated +2. **Migration Path**: Server should accept both old and new formats during transition +3. **Auto-Detection**: If `backend_type` is unset but `nixl_metadata` is present, infer `BACKEND_TYPE_NIXL` + +**Example Migration Code**: +```rust +impl From for WorkerRecord { + fn from(meta: WorkerMetadata) -> Self { + let (backend_type, metadata_bytes) = match meta.backend_type { + BackendType::Nixl | BackendType::Unspecified => { + // Handle legacy: if backend_type unset but nixl_metadata present + if !meta.nixl_metadata.is_empty() { + (BackendType::Nixl, meta.nixl_metadata) + } else if let Some(nixl) = meta.backend_metadata.nixl_metadata { + (BackendType::Nixl, nixl.nixl_agent_metadata) + } else { + // Error: no metadata + return Err(...); + } + } + BackendType::TransferEngine => { + if let Some(te) = meta.backend_metadata.transfer_engine_metadata { + // Serialize TransferEngine metadata + (BackendType::TransferEngine, serialize_te_metadata(te)?) + } else { + return Err(...); + } + } + }; + + Self { + worker_rank: meta.worker_rank, + backend_type, + backend_metadata: metadata_bytes, + tensors: ... + } + } +} +``` + +### 2. Server-Side Storage Design + +#### βœ… **Recommendation: Store Backend Type in WorkerRecord** + +**Current Structure**: +```rust +pub struct WorkerRecord { + pub worker_rank: u32, + pub nixl_metadata: Vec, // Backend-agnostic name needed + pub tensors: Vec, +} +``` + +**Recommended Structure**: +```rust +pub struct WorkerRecord { + pub worker_rank: u32, + pub backend_type: BackendType, // NEW: Track backend type + pub backend_metadata: Vec, // RENAMED: Generic name (was nixl_metadata) + pub tensors: Vec, +} +``` + +**Rationale**: +- **Clarity**: `backend_metadata` is more accurate than `nixl_metadata` +- **Type Safety**: Backend type is explicit in storage layer +- **Query Support**: Can filter/query by backend type if needed + +#### ⚠️ **Concern: Storage Backend Compatibility** + +**Issue**: Redis/Kubernetes backends serialize `WorkerRecord`. How does PR 157 handle: +1. Existing stored data (only NIXL)? +2. Mixed deployments (some workers NIXL, some TransferEngine)? + +**Recommendations**: +1. **Default Backend Type**: When deserializing old data without `backend_type`, default to `BackendType::Nixl` +2. **Versioned Schema**: Consider adding a `schema_version` field for future migrations +3. **Validation**: Reject metadata where backend_type doesn't match metadata format + +**Example**: +```rust +impl From for WorkerRecord { + fn from(json: WorkerRecordJson) -> Self { + Self { + worker_rank: json.worker_rank, + backend_type: json.backend_type.unwrap_or(BackendType::Nixl), // Default for old data + backend_metadata: json.backend_metadata, // Was nixl_metadata + tensors: ... + } + } +} +``` + +### 3. Client-Side Backend Selection + +#### βœ… **Recommendation: Factory Pattern for Transfer Managers** + +**Current Approach**: +```python +class NixlTransferManager: + def __init__(self, agent_name: str, device_id: int): + ... +``` + +**Recommended Approach**: +```python +class TransferManagerFactory: + @staticmethod + def create( + backend_type: BackendType, + agent_name: str, + device_id: int, + **kwargs + ) -> TransferManager: + if backend_type == BackendType.NIXL: + return NixlTransferManager(agent_name, device_id) + elif backend_type == BackendType.TRANSFER_ENGINE: + return TransferEngineTransferManager( + agent_name, device_id, + seed_instance_ip=kwargs.get("seed_instance_ip"), + seed_instance_port=kwargs.get("seed_instance_port"), + ... + ) + else: + raise ValueError(f"Unsupported backend: {backend_type}") + +# Usage +metadata = get_metadata_from_server(model_name) +for worker in metadata.workers: + manager = TransferManagerFactory.create( + backend_type=worker.backend_type, + agent_name=f"worker_{worker.worker_rank}", + device_id=worker.worker_rank, + **extract_transfer_engine_config(worker.backend_metadata) + ) +``` + +**Rationale**: +- **Clean Separation**: Each backend has its own manager +- **Easy Testing**: Can mock individual backends +- **Configuration**: Backend-specific config passed via kwargs + +#### ⚠️ **Concern: Backend Availability Detection** + +**Issue**: How does the client know which backends are available at runtime? + +**Recommendations**: +1. **Runtime Detection**: Check for NIXL/TransferEngine availability (similar to `is_nixl_available()`) +2. **Fallback Strategy**: If preferred backend unavailable, fall back to alternative +3. **Error Messages**: Clear errors when required backend is missing + +**Example**: +```python +def select_backend(preferred: BackendType) -> BackendType: + """Select available backend with fallback.""" + if preferred == BackendType.TRANSFER_ENGINE: + if is_transfer_engine_available(): + return BackendType.TRANSFER_ENGINE + elif is_nixl_available(): + logger.warning("TransferEngine not available, falling back to NIXL") + return BackendType.NIXL + else: + raise RuntimeError("No transfer backend available") + elif preferred == BackendType.NIXL: + if is_nixl_available(): + return BackendType.NIXL + else: + raise RuntimeError("NIXL not available") + ... +``` + +### 4. Alignment with SGLang R-Fork + +#### βœ… **Recommendation: Match SGLang's Configuration Pattern** + +SGLang uses command-line arguments for TransferEngine configuration: +```bash +--load-format remote_instance +--remote-instance-weight-loader-backend transfer_engine +--remote-instance-weight-loader-seed-instance-ip +--remote-instance-weight-loader-seed-instance-service-port +``` + +**ModelExpress Equivalent**: +```python +# Environment variables or config +MX_TRANSFER_BACKEND=transfer_engine +MX_TRANSFER_ENGINE_SEED_IP= +MX_TRANSFER_ENGINE_SEED_PORT= +``` + +**Recommendations**: +1. **Consistent Naming**: Use similar parameter names to SGLang for familiarity +2. **Documentation**: Reference SGLang's R-Fork docs in ModelExpress docs +3. **Validation**: Validate that seed instance is reachable before publishing metadata + +### 5. Metadata Exchange & Routing + +#### βœ… **Recommendation: Backend-Aware Metadata Routing** + +**Issue**: When target receives metadata, it must route to correct backend. + +**Current Flow**: +``` +Target β†’ GetMetadata(model_name) β†’ Server β†’ Returns WorkerMetadata +Target β†’ Extract nixl_metadata β†’ Add remote NIXL agent +``` + +**Recommended Flow**: +``` +Target β†’ GetMetadata(model_name) β†’ Server β†’ Returns WorkerMetadata (with backend_type) +Target β†’ Check backend_type β†’ Route to appropriate manager: + - NIXL β†’ NixlTransferManager.add_remote_agent(nixl_metadata) + - TransferEngine β†’ TransferEngineTransferManager.connect(te_metadata) +``` + +**Implementation**: +```python +def load_model_from_source(model_name: str): + metadata = client.get_metadata(model_name) + + for worker in metadata.workers: + if worker.backend_type == BackendType.NIXL: + manager = get_nixl_manager(worker.worker_rank) + manager.add_remote_agent(worker.backend_metadata) + elif worker.backend_type == BackendType.TRANSFER_ENGINE: + manager = get_transfer_engine_manager(worker.worker_rank) + te_config = deserialize_transfer_engine_metadata(worker.backend_metadata) + manager.connect_to_seed(te_config) +``` + +### 6. Error Handling & Validation + +#### ⚠️ **Concerns** + +1. **Mismatched Backends**: What if source uses TransferEngine but target only has NIXL? +2. **Metadata Corruption**: Invalid backend_metadata for declared backend_type +3. **Connection Failures**: TransferEngine seed instance unreachable + +**Recommendations**: +1. **Validation**: Server should validate backend_type matches metadata format +2. **Error Messages**: Clear errors: "Source uses TransferEngine but target only supports NIXL" +3. **Fallback**: Consider automatic fallback if preferred backend unavailable (with user opt-in) + +**Example Validation**: +```rust +fn validate_worker_metadata(worker: &WorkerMetadata) -> Result<()> { + match worker.backend_type { + BackendType::Nixl => { + if worker.backend_metadata.is_empty() { + return Err("NIXL backend requires non-empty metadata"); + } + // Could also validate NIXL metadata format + } + BackendType::TransferEngine => { + let te_meta = deserialize_transfer_engine_metadata(&worker.backend_metadata)?; + if te_meta.seed_instance_ip.is_empty() { + return Err("TransferEngine requires seed_instance_ip"); + } + } + _ => return Err("Unsupported backend type"), + } + Ok(()) +} +``` + +### 7. Testing & Compatibility + +#### βœ… **Recommendations** + +1. **Unit Tests**: + - Test backend type serialization/deserialization + - Test migration from old format (nixl_metadata) to new format + - Test validation logic + +2. **Integration Tests**: + - Test NIXL-only deployment (backward compatibility) + - Test TransferEngine-only deployment + - Test mixed deployment (some workers NIXL, some TransferEngine) + +3. **Compatibility Tests**: + - Old client β†’ New server (should work) + - New client β†’ Old server (should handle gracefully) + +**Example Test**: +```rust +#[test] +fn test_backward_compatibility_old_nixl_metadata() { + // Simulate old WorkerMetadata with only nixl_metadata field + let old_meta = WorkerMetadata { + worker_rank: 0, + backend_type: BackendType::Unspecified, // Old format + nixl_metadata: vec![1, 2, 3, 4], // Old field + backend_metadata: None, // New field not set + tensors: vec![], + }; + + let record = WorkerRecord::from(old_meta); + assert_eq!(record.backend_type, BackendType::Nixl); // Auto-detected + assert_eq!(record.backend_metadata, vec![1, 2, 3, 4]); +} +``` + +## Specific PR Feedback Items + +### High Priority + +1. **Backward Compatibility**: Ensure existing NIXL-only deployments continue to work without changes +2. **Protocol Buffer Design**: Use `oneof` for backend-specific metadata (see Section 1) +3. **Storage Layer**: Rename `nixl_metadata` to `backend_metadata` and add `backend_type` field +4. **Validation**: Add server-side validation that backend_type matches metadata format + +### Medium Priority + +5. **Client Factory**: Implement factory pattern for transfer manager creation +6. **Error Handling**: Clear error messages for backend mismatches +7. **Documentation**: Update `docs/metadata.md` with TransferEngine backend information +8. **Configuration**: Align parameter names with SGLang's R-Fork for consistency + +### Low Priority + +9. **Future-Proofing**: Consider NCCL backend support (similar pattern) +10. **Observability**: Add metrics/logging for backend type usage +11. **Testing**: Comprehensive test coverage for all backend combinations + +## Questions for PR Author + +1. **Migration Strategy**: How are existing deployments migrated? Is there a migration script? +2. **Backend Selection**: How does the system decide which backend to use? User config or auto-detection? +3. **Mixed Deployments**: Can a single model have workers using different backends (e.g., worker 0 NIXL, worker 1 TransferEngine)? +4. **TransferEngine Implementation**: Is TransferEngine a separate library, or is it part of NIXL? What are the dependencies? +5. **Performance Comparison**: Are there benchmarks comparing NIXL vs TransferEngine performance? +6. **SGLang Integration**: Is this change intended to enable ModelExpress to work with SGLang's R-Fork feature? + +## Conclusion + +The addition of TransferEngine backend support is a valuable enhancement that aligns ModelExpress with SGLang's R-Fork capabilities. The key concerns are: + +1. **Design**: Use `oneof` for backend-specific metadata to ensure type safety and extensibility +2. **Compatibility**: Maintain backward compatibility with existing NIXL deployments +3. **Validation**: Ensure backend_type and metadata format are consistent +4. **Testing**: Comprehensive test coverage for all scenarios + +The recommended approach provides a clean, extensible design that can support additional backends (NCCL, custom) in the future while maintaining compatibility with existing deployments. + +--- + +## PR Review Comments + +This section provides specific comments to make directly on PR 157, organized by file and approximate line numbers. These comments should be added as inline code review comments on the PR. + +**Note**: These comments are based on the actual implementation in the `ishan/transfer-engine-backend` branch. The PR has already implemented the `oneof` pattern for backend metadata, which is excellent! The comments below address specific aspects of the implementation. + +### Protocol Buffer Changes + +#### File: `modelexpress_common/proto/p2p.proto` + +**Comment 1 - Line ~57-60 (WorkerMetadata message)** +``` +βœ… Excellent: Using `oneof` for backend metadata! + +Great implementation! The `oneof backend_metadata` pattern provides type safety +and clear separation. One suggestion: + +Consider adding a comment explaining the format of `transfer_engine_session_id`: +```protobuf +// TransferEngine: Mooncake session ID in format "ip:port" (e.g., "10.0.0.1:8000") +string transfer_engine_session_id = 10; +``` + +This helps users understand the expected format. Also, consider if a structured +message would be better for future extensibility (e.g., if you need to add +additional TransferEngine connection parameters later). +``` + +**Comment 2 - Line ~50-60 (WorkerMetadata message)** +``` +⚠️ Backward Compatibility Concern + +If the existing `bytes nixl_metadata = 2` field is being kept for compatibility, +please ensure: + +1. The field is marked as deprecated in comments +2. Server-side conversion handles both old and new formats +3. Auto-detection: If `backend_type` is unset but `nixl_metadata` is present, + infer `BACKEND_TYPE_NIXL` + +This is critical for existing deployments that won't be updated immediately. +``` + +**Comment 3 - Line ~92-97 (if BackendType enum is added)** +``` +βœ… Good: BackendType enum definition + +If adding a BackendType enum, ensure: +- `BACKEND_TYPE_UNSPECIFIED = 0` is the default (protobuf best practice) +- Values match the pattern used in SGLang's R-Fork for consistency +- Consider future-proofing with `BACKEND_TYPE_NCCL = 3` even if not implemented yet +``` + +**Comment 4 - Line ~103-109 (if TransferEngineBackendMetadata message is added)** +``` +πŸ“ Documentation Suggestion + +The TransferEngineBackendMetadata message should include: +- `seed_instance_ip`: IP address of seed instance (required) +- `seed_instance_service_port`: HTTP service port (required) +- `send_weights_group_ports`: For NCCL backend variant (optional, repeated) +- Comments explaining each field's purpose + +Consider aligning field names with SGLang's R-Fork parameters for familiarity: +- `--remote-instance-weight-loader-seed-instance-ip` +- `--remote-instance-weight-loader-seed-instance-service-port` +``` + +### Server-Side Rust Changes + +#### File: `modelexpress_server/src/metadata_backend.rs` + +**Comment 5 - Line ~64-68 (WorkerRecord struct)** +``` +βœ… Excellent: Using `BackendMetadataRecord` enum! + +Great design! The `BackendMetadataRecord` enum provides type safety and makes +the backend type explicit. The implementation looks clean. + +One observation: The `BackendMetadataRecord::None` variant (line 43) - is this +intentionally allowed? If a worker has no backend metadata, should we reject +it during validation, or is this for a specific use case? Consider adding +validation in `publish_metadata` to ensure at least one backend is provided. +``` + +**Comment 6 - Line ~81-96 (From for WorkerRecord)** +``` +βœ… Clean Implementation: Conversion logic looks good! + +The conversion from `WorkerMetadata` to `WorkerRecord` correctly handles the +`oneof` pattern. One suggestion: + +Consider adding validation to ensure at least one backend metadata is provided: +```rust +impl From for WorkerRecord { + fn from(meta: WorkerMetadata) -> Self { + use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata; + let backend_metadata = match meta.backend_metadata { + Some(BackendMetadata::NixlMetadata(data)) => { + if data.is_empty() { + tracing::warn!("Empty NIXL metadata for worker {}", meta.worker_rank); + } + BackendMetadataRecord::Nixl(data) + } + Some(BackendMetadata::TransferEngineSessionId(sid)) => { + if sid.is_empty() { + tracing::warn!("Empty TransferEngine session ID for worker {}", meta.worker_rank); + } + BackendMetadataRecord::TransferEngine(sid) + } + None => { + tracing::warn!("No backend metadata provided for worker {}", meta.worker_rank); + BackendMetadataRecord::None + } + }; + ... + } +} +``` + +This helps catch configuration errors early. +``` + +**Comment 7 - Line ~77-88 (From for WorkerMetadata)** +``` +πŸ”„ Conversion Logic: Ensure bidirectional conversion works + +The reverse conversion `From for WorkerMetadata` must: +1. Set `backend_type` field correctly +2. Populate the appropriate `oneof` field based on `backend_type` +3. Handle legacy `nixl_metadata` field for backward compatibility + +This ensures targets can correctly deserialize and route to the right backend. +``` + +#### File: `modelexpress_server/src/p2p_service.rs` + +**Comment 8 - Line ~49-59 (BackendMetadataRecord::from_flat)** +``` +⚠️ Priority Logic: TransferEngine takes priority + +The `from_flat` method gives TransferEngine priority when both `nixl_metadata` +and `transfer_engine_session_id` are present (line 50-53). This is reasonable, +but consider: + +1. **Documentation**: Add a comment explaining why TransferEngine takes priority +2. **Validation**: Should we warn or error if both are provided? It might indicate + a configuration mistake +3. **Consistency**: Ensure this priority is consistent across all code paths + +Suggestion: +```rust +pub fn from_flat(nixl_metadata: Vec, transfer_engine_session_id: Option) -> Self { + if let Some(sid) = transfer_engine_session_id + && !sid.is_empty() + { + // TransferEngine takes priority when both are present + if !nixl_metadata.is_empty() { + tracing::warn!( + "Both NIXL and TransferEngine metadata provided, using TransferEngine" + ); + } + return Self::TransferEngine(sid); + } + ... +} +``` +``` + +**Comment 9 - Line ~84-119 (get_metadata implementation)** +``` +πŸ“Š Logging Enhancement + +When returning metadata, log the backend types being returned: +```rust +info!( + "Found metadata for model '{}': {} workers (backends: {:?}), {} tensors", + req.model_name, + record.workers.len(), + record.workers.iter().map(|w| w.backend_type).collect::>(), + total_tensors +); +``` + +This helps with debugging mixed-backend deployments. +``` + +### Client-Side Python Changes + +#### File: `modelexpress_client/python/modelexpress/nixl_transfer.py` (or new file) + +**Comment 10 - Line ~1-50 (if creating TransferEngineTransferManager)** +``` +🏭 Factory Pattern Suggestion + +Consider creating a factory for transfer managers to handle backend selection: + +```python +class TransferManagerFactory: + @staticmethod + def create( + backend_type: BackendType, + agent_name: str, + device_id: int, + **kwargs + ) -> TransferManager: + if backend_type == BackendType.NIXL: + return NixlTransferManager(agent_name, device_id) + elif backend_type == BackendType.TRANSFER_ENGINE: + return TransferEngineTransferManager( + agent_name, device_id, + seed_instance_ip=kwargs.get("seed_instance_ip"), + seed_instance_port=kwargs.get("seed_instance_port"), + ) + else: + raise ValueError(f"Unsupported backend: {backend_type}") +``` + +This provides clean separation and makes testing easier. +``` + +**Comment 11 - Line ~37-40 (is_nixl_available function)** +``` +πŸ” Backend Availability Detection + +Add a similar function for TransferEngine: +```python +def is_transfer_engine_available() -> bool: + """Check if TransferEngine is available.""" + try: + # Import TransferEngine library + from transfer_engine import TransferEngine + return True + except ImportError: + return False +``` + +Also consider a backend selection function with fallback: +```python +def select_available_backend(preferred: BackendType) -> BackendType: + """Select available backend with fallback.""" + if preferred == BackendType.TRANSFER_ENGINE: + if is_transfer_engine_available(): + return BackendType.TRANSFER_ENGINE + elif is_nixl_available(): + logger.warning("TransferEngine not available, falling back to NIXL") + return BackendType.NIXL + ... +``` +``` + +#### File: `modelexpress_client/python/modelexpress/` (vLLM loader integration) + +**Comment 12 - Line ~TBD (where metadata is consumed)** +``` +πŸ”„ Backend-Aware Routing Required + +When target receives metadata from `get_metadata()`, ensure it routes to the +correct backend based on `backend_type`: + +```python +def load_model_from_source(model_name: str): + metadata = client.get_metadata(model_name) + + for worker in metadata.workers: + if worker.backend_type == BackendType.NIXL: + manager = get_nixl_manager(worker.worker_rank) + manager.add_remote_agent(worker.backend_metadata) + elif worker.backend_type == BackendType.TRANSFER_ENGINE: + manager = get_transfer_engine_manager(worker.worker_rank) + te_config = deserialize_transfer_engine_metadata(worker.backend_metadata) + manager.connect_to_seed(te_config) + else: + raise ValueError(f"Unsupported backend: {worker.backend_type}") +``` + +This ensures targets can handle sources using different backends. +``` + +**Comment 13 - Line ~TBD (error handling)** +``` +⚠️ Error Handling: Backend Mismatch + +Add clear error handling when source and target backends don't match: + +```python +if worker.backend_type == BackendType.TRANSFER_ENGINE: + if not is_transfer_engine_available(): + raise RuntimeError( + f"Source worker {worker.worker_rank} uses TransferEngine backend, " + "but TransferEngine is not available on this target. " + "Please install TransferEngine or use a source with NIXL backend." + ) +``` + +Provide actionable error messages to help users resolve issues. +``` + +### Storage Backend Changes + +#### File: `modelexpress_server/src/metadata_backend/redis.rs` + +**Comment 14 - Line ~125-147 (WorkerRecordJson)** +``` +πŸ”„ JSON Serialization: Handle backend_type field + +The `WorkerRecordJson` struct needs to include `backend_type`: + +```rust +#[derive(Debug, Clone, Serialize, Deserialize)] +struct WorkerRecordJson { + pub worker_rank: u32, + pub backend_type: Option, // NEW: Optional for backward compat + pub backend_metadata: Vec, // RENAMED from nixl_metadata + pub tensors: Vec, +} +``` + +In `From for WorkerRecord`, default to `BackendType::Nixl` +if `backend_type` is `None` (for old stored data). +``` + +#### File: `modelexpress_server/src/metadata_backend/kubernetes.rs` + +**Comment 15 - Line ~TBD (WorkerStatus in k8s_types.rs)** +``` +πŸ“ Kubernetes CRD: Add backend_type field + +The `WorkerStatus` struct in `k8s_types.rs` should include: +```rust +pub struct WorkerStatus { + pub worker_rank: i32, + pub backend_type: Option, // "nixl", "transfer_engine", etc. + pub nixl_metadata: String, // Consider renaming to backend_metadata + ... +} +``` + +Update the CRD schema in `examples/p2p_transfer_k8s/deploy/persistence/crd-modelmetadata.yaml` +to include the backend_type field. +``` + +### Testing + +#### File: `modelexpress_server/src/metadata_backend.rs` (test module) + +**Comment 16 - Line ~TBD (add new tests)** +``` +βœ… Test Coverage Needed + +Please add tests for: +1. **Backward compatibility**: Old WorkerMetadata with only `nixl_metadata` field +2. **New format**: WorkerMetadata with `backend_type` and `oneof` fields +3. **Migration**: Conversion from old to new format +4. **Validation**: Reject invalid backend_type/metadata combinations +5. **Mixed deployments**: Model with some workers NIXL, some TransferEngine + +Example: +```rust +#[test] +fn test_backward_compatibility_old_nixl_metadata() { + let old_meta = WorkerMetadata { + worker_rank: 0, + backend_type: BackendType::Unspecified, + nixl_metadata: vec![1, 2, 3, 4], // Old field + backend_metadata: None, // New field not set + tensors: vec![], + }; + + let record = WorkerRecord::from(old_meta); + assert_eq!(record.backend_type, BackendType::Nixl); // Auto-detected +} +``` +``` + +### Documentation + +#### File: `docs/metadata.md` + +**Comment 17 - Line ~1-10 (Overview section)** +``` +πŸ“ Documentation Update Needed + +Please update the overview to mention TransferEngine backend support: + +```markdown +## Overview + +ModelExpress P2P transfers require coordination between source and target instances: +1. **Source** publishes transfer backend metadata (NIXL agent info or TransferEngine + connection info + tensor descriptors) after loading model weights +2. **Target** queries for source metadata to establish connections (RDMA for NIXL, + TransferEngine connection for TransferEngine backend) +3. **Coordination** signals ensure targets wait for sources to be fully ready +``` + +Also add a new section explaining TransferEngine backend usage and configuration. +``` + +**Comment 18 - Line ~TBD (add TransferEngine section)** +``` +πŸ“š New Section: TransferEngine Backend + +Add a section explaining: +1. When to use TransferEngine vs NIXL +2. Configuration parameters (align with SGLang R-Fork) +3. Example usage +4. Troubleshooting common issues + +Reference: https://raw.githubusercontent.com/sgl-project/sglang/main/docs/advanced_features/rfork.md +``` + +### Configuration & Environment Variables + +#### File: `README.md` or new config documentation + +**Comment 19 - Line ~TBD** +``` +βš™οΈ Configuration Documentation + +Document the new environment variables for TransferEngine: +- `MX_TRANSFER_BACKEND`: Backend type (`nixl`, `transfer_engine`, default: `nixl`) +- `MX_TRANSFER_ENGINE_SEED_IP`: Seed instance IP (required for TransferEngine) +- `MX_TRANSFER_ENGINE_SEED_PORT`: Seed instance service port (required for TransferEngine) + +Align naming with SGLang's R-Fork parameters for consistency. +``` + +### Additional Comments Based on Actual Implementation + +#### File: `modelexpress_common/proto/p2p.proto` + +**Comment 20 - Line ~59 (transfer_engine_session_id field)** +``` +πŸ“ Format Documentation Needed + +The `transfer_engine_session_id` is described as "ip:port" format. Consider: + +1. **Validation**: Add format validation (e.g., regex or parsing) to ensure it's + a valid "ip:port" format +2. **Documentation**: Add example in comment: `// Format: "10.0.0.1:8000"` +3. **Future-proofing**: If you need additional TransferEngine connection parameters + later (e.g., authentication tokens, protocol version), consider using a structured + message instead of a string + +Current approach is fine for MVP, but structured message would be more extensible: +```protobuf +message TransferEngineBackendMetadata { + string seed_instance_ip = 1; + uint32 seed_instance_service_port = 2; + // Future: repeated uint32 send_weights_group_ports = 3; +} +``` +``` + +#### File: `modelexpress_server/src/metadata_backend.rs` + +**Comment 21 - Line ~43 (BackendMetadataRecord::None)** +``` +❓ Design Question: When is `None` valid? + +The `BackendMetadataRecord::None` variant suggests workers can exist without +backend metadata. Is this intentional? Consider: + +1. **Use case**: When would a worker have no backend metadata? Is this for + a specific deployment scenario? +2. **Validation**: Should `publish_metadata` reject workers with `None` backend? +3. **Documentation**: Add a comment explaining when `None` is acceptable + +If `None` is not a valid state, consider removing it and making the enum +non-optional, or add validation to reject it. +``` + +**Comment 22 - Line ~49-59 (from_flat priority logic)** +``` +βœ… Good: Priority logic is clear + +The priority logic (TransferEngine > NIXL > None) is reasonable. One enhancement: + +Consider logging when priority is applied to help with debugging: +```rust +pub fn from_flat(nixl_metadata: Vec, transfer_engine_session_id: Option) -> Self { + let has_nixl = !nixl_metadata.is_empty(); + let has_te = transfer_engine_session_id.as_ref() + .map(|s| !s.is_empty()) + .unwrap_or(false); + + if has_te && has_nixl { + tracing::debug!( + "Both NIXL and TransferEngine metadata present, using TransferEngine (priority)" + ); + } + + if let Some(sid) = transfer_engine_session_id + && !sid.is_empty() + { + return Self::TransferEngine(sid); + } + ... +} +``` +``` + +### Summary of Priority Comments + +**High Priority (Must Address)**: +- Comment 2: Backward compatibility handling (if old format still exists) +- Comment 8: Priority logic documentation and validation +- Comment 20: TransferEngine session ID format validation +- Comment 21: Clarify when `BackendMetadataRecord::None` is valid + +**Medium Priority (Should Address)**: +- Comment 5: Validation for empty backend metadata +- Comment 6: Add validation/warnings for empty metadata +- Comment 10: Factory pattern for transfer managers (client-side) +- Comment 12: Backend-aware routing in client +- Comment 16: Test coverage for all scenarios +- Comment 17: Documentation updates + +**Low Priority (Nice to Have)**: +- Comment 1: Enhanced documentation for TransferEngine session ID format +- Comment 9: Enhanced logging +- Comment 11: Backend availability detection with fallback +- Comment 19: Configuration documentation +- Comment 22: Enhanced logging for priority logic + +--- + +## Backend Selection Logic: TransferEngine vs NIXL + +### Current State + +Based on the codebase review, **the backend selection logic is not yet fully implemented**. Here's what I found: + +1. **Protocol Support**: The `p2p.proto` file supports both backends via `oneof`: + ```protobuf + oneof backend_metadata { + bytes nixl_metadata = 2; + string transfer_engine_session_id = 10; + } + ``` + +2. **Client Implementation**: Currently, the client code (`vllm_loader.py` line 338) **only sets NIXL metadata**: + ```python + worker = p2p_pb2.WorkerMetadata( + worker_rank=device_id, + nixl_metadata=nixl_metadata, # Only NIXL is set + tensors=tensor_protos, + ) + ``` + +3. **No Selection Logic**: There's no configuration or code that chooses between TransferEngine and NIXL. + +### How It Should Work + +The backend selection should happen at **two points**: + +#### 1. Source Side (When Publishing Metadata) + +The source decides which backend to use based on: +- **Configuration**: Environment variable or config file +- **Availability**: Runtime detection of which backends are available +- **User preference**: Explicit configuration + +**Recommended Implementation**: +```python +# In vllm_loader.py _publish_metadata_to_server() +def _publish_metadata_to_server(self, raw_tensors, device_id): + # Determine which backend to use + backend_type = self._select_backend() # NEW: Selection logic + + if backend_type == "transfer_engine": + # Initialize TransferEngine and get session ID + te_session_id = self._get_transfer_engine_session_id() + worker = p2p_pb2.WorkerMetadata( + worker_rank=device_id, + transfer_engine_session_id=te_session_id, # Set TE field + tensors=tensor_protos, + ) + else: # Default to NIXL + nixl_metadata = self._nixl_manager.nixl_metadata if self._nixl_manager else b"" + worker = p2p_pb2.WorkerMetadata( + worker_rank=device_id, + nixl_metadata=nixl_metadata, # Set NIXL field + tensors=tensor_protos, + ) + + self._mx_client.publish_metadata(model_name, [worker]) + +def _select_backend(self) -> str: + """Select backend based on configuration and availability.""" + # Check explicit configuration + configured_backend = os.environ.get("MX_TRANSFER_BACKEND", "nixl") + + if configured_backend == "transfer_engine": + if is_transfer_engine_available(): + return "transfer_engine" + else: + logger.warning("TransferEngine not available, falling back to NIXL") + return "nixl" + else: + return "nixl" # Default +``` + +#### 2. Target Side (When Receiving Metadata) + +The target must use **whatever backend the source published**. The target cannot choose - it must match the source's backend. + +**Recommended Implementation**: +```python +# In vllm_loader.py load_model() for target +def load_model(self, ...): + # Get metadata from server + metadata_response = self._mx_client.get_metadata(model_name) + + for worker in metadata_response.workers: + # Check which backend the source used + if worker.HasField("transfer_engine_session_id"): + # Source uses TransferEngine + if not is_transfer_engine_available(): + raise RuntimeError( + f"Source worker {worker.worker_rank} uses TransferEngine, " + "but TransferEngine is not available on this target" + ) + # Use TransferEngine to connect + self._connect_via_transfer_engine(worker.transfer_engine_session_id) + + elif worker.HasField("nixl_metadata"): + # Source uses NIXL + if not is_nixl_available(): + raise RuntimeError( + f"Source worker {worker.worker_rank} uses NIXL, " + "but NIXL is not available on this target" + ) + # Use NIXL to connect + self._connect_via_nixl(worker.nixl_metadata) + else: + raise RuntimeError("Source worker has no backend metadata") +``` + +### Configuration Options + +**Recommended Environment Variables**: + +1. **`MX_TRANSFER_BACKEND`**: Primary backend selection + - Values: `nixl` (default), `transfer_engine`, `auto` + - `auto`: Try TransferEngine first, fallback to NIXL + +2. **`MX_TRANSFER_ENGINE_ENABLED`**: Explicit enable/disable + - Values: `true`, `false` (default: `false`) + - Overrides `MX_TRANSFER_BACKEND` if set to `false` + +3. **Runtime Detection**: Check availability at runtime + ```python + def is_transfer_engine_available() -> bool: + try: + from transfer_engine import TransferEngine + return True + except ImportError: + return False + ``` + +### Priority/Precedence Rules + +1. **Source publishes with one backend** β†’ Target must use the same backend +2. **If source uses TransferEngine but target doesn't have it** β†’ Error (clear message) +3. **If source uses NIXL but target doesn't have it** β†’ Error (clear message) +4. **If both are available** β†’ Use source's choice (no negotiation) + +### Missing Implementation + +Based on the code review, the following is **missing**: + +1. βœ… **Proto support**: Already implemented (`oneof` pattern) +2. ❌ **Source selection logic**: Not implemented (always uses NIXL) +3. ❌ **Target routing logic**: Not implemented (always expects NIXL) +4. ❌ **TransferEngine client code**: Not implemented +5. ❌ **Configuration variables**: Not documented/implemented +6. ❌ **Availability detection**: Not implemented + +### Recommendation + +Add explicit backend selection logic to the client code: + +1. **Add configuration**: `MX_TRANSFER_BACKEND` environment variable +2. **Add selection method**: `_select_backend()` in `MxSourceModelLoader` +3. **Add routing method**: Check `HasField()` in `MxTargetModelLoader` +4. **Add TransferEngine manager**: Similar to `NixlTransferManager` +5. **Add tests**: Test both backends and mixed scenarios + +This ensures the backend selection is **explicit and configurable**, rather than implicit or hardcoded. diff --git a/docs/feedback_pr19920.md b/docs/feedback_pr19920.md new file mode 100644 index 00000000..f97f9538 --- /dev/null +++ b/docs/feedback_pr19920.md @@ -0,0 +1,863 @@ +# PR 19920: [1/2] Add ModelExpress coordination for remote instance weight loading - matching TP + +## Executive Summary + +This document provides a design review and feedback for SGLang PR 19920, which adds ModelExpress coordination for remote instance weight loading. The PR integrates ModelExpress gRPC server as a coordination layer for TransferEngine-based weight transfers, replacing direct HTTP communication between seed and target instances. + +**Key Changes:** +- Adds `MODEL_EXPRESS` backend option for `remote_instance_weight_loader_backend` +- Integrates ModelExpress client for metadata coordination +- Supports TP rank matching between seed and target instances +- Uses TransferEngine for actual RDMA transfers (coordinated via ModelExpress) + +## Architecture Overview + +### Current Flow (Before PR 19920) + +**NCCL Backend:** +``` +Seed Instance β†’ Direct HTTP β†’ Target Instance + - Seed publishes TransferEngine session ID via HTTP endpoint + - Target queries seed HTTP endpoint for session ID + - Target connects directly to seed via TransferEngine +``` + +**TransferEngine Backend (Direct):** +``` +Seed Instance β†’ HTTP endpoint β†’ Target Instance + - Seed exposes /get_remote_instance_transfer_engine_info + - Target queries per-rank session IDs + - Direct TransferEngine connection +``` + +### New Flow (After PR 19920) + +**ModelExpress Backend:** +``` +Seed Instance β†’ ModelExpress Server β†’ Target Instance + - Seed publishes metadata to ModelExpress gRPC server + - Target queries ModelExpress for seed metadata + - ModelExpress coordinates ready state + - Target connects to seed via TransferEngine (using session ID from metadata) +``` + +## Implementation Review + +### 1. Protocol Buffer Integration + +#### βœ… **Good: Correct Use of `oneof` Pattern** + +**File**: `python/sglang/srt/model_loader/loader.py` (line ~2340) + +The implementation correctly uses the `oneof` pattern to extract TransferEngine session ID: + +```python +backend_field = source_worker.WhichOneof("backend_metadata") +if backend_field == "transfer_engine_session_id": + seed_session_id = source_worker.transfer_engine_session_id +else: + raise RuntimeError( + f"ModelExpress: expected transfer_engine_session_id, " + f"got backend_metadata={backend_field}" + ) +``` + +**Comment**: This correctly handles the `oneof` pattern from ModelExpress PR 157. Good error handling when the wrong backend type is present. + +#### ⚠️ **Concern: No Fallback for NIXL Backend** + +**Issue**: The code only handles `transfer_engine_session_id` and raises an error for other backends. What if the source uses NIXL backend? + +**Recommendation**: Add support for NIXL backend or provide a clear error message: + +```python +backend_field = source_worker.WhichOneof("backend_metadata") +if backend_field == "transfer_engine_session_id": + seed_session_id = source_worker.transfer_engine_session_id +elif backend_field == "nixl_metadata": + raise RuntimeError( + f"ModelExpress: source worker {tp_rank} uses NIXL backend, " + f"but MODEL_EXPRESS backend requires TransferEngine. " + f"Please use a source with TransferEngine backend or use NIXL directly." + ) +else: + raise RuntimeError( + f"ModelExpress: unknown backend_metadata={backend_field} " + f"for worker {tp_rank}" + ) +``` + +### 2. Source Side: Publishing Metadata + +#### βœ… **Good: Proper Metadata Publishing** + +**File**: `python/sglang/srt/model_executor/model_runner.py` (line ~680-750) + +The `_publish_model_express_metadata()` function: +- Correctly builds tensor descriptors from weight info +- Uses `transfer_engine_session_id` in the `oneof` field +- Publishes both metadata and ready flag +- Handles element size to dtype mapping for FP8 models + +**Comment**: The implementation correctly uses byte sizes (`numel * element_size`) for tensor descriptors, which is important for mixed-dtype models (FP8 + BF16). + +#### ⚠️ **Concern: Dtype Inference from Element Size** + +**File**: `python/sglang/srt/model_executor/model_runner.py` (line ~700) + +```python +element_size_to_dtype = {1: "float8_e4m3fn", 2: "bfloat16", 4: "float32", 8: "float64"} +``` + +**Issue**: This mapping is lossy. Multiple dtypes can have the same element size: +- Element size 2: `float16`, `bfloat16`, `int16`, `uint16` +- Element size 1: `int8`, `uint8`, `float8_e4m3fn`, `float8_e5m2` + +**Recommendation**: Use actual tensor dtype instead of inferring from element size: + +```python +tensors = [] +for name, (addr, numel, element_size) in weight_info.items(): + # Get actual tensor to determine dtype + tensor = dict(model.named_parameters())[name] + dtype_str = str(tensor.dtype).replace("torch.", "") + + tensors.append(p2p_pb2.TensorDescriptor( + name=name, + addr=addr, + size=numel * element_size, + device_id=self.gpu_id, + dtype=dtype_str, # Use actual dtype + )) +``` + +**Alternative**: If weight_info doesn't include tensor references, add dtype to the weight_info tuple: +```python +# In register_memory_region, return (addr, numel, element_size, dtype_str) +weight_info[name] = (addr, numel, element_size, str(tensor.dtype).replace("torch.", "")) +``` + +#### βœ… **Good: TP Rank Matching** + +**File**: `python/sglang/srt/model_loader/loader.py` (line ~2310) + +The code correctly matches TP ranks: +```python +for w in response.workers: + if w.worker_rank == tp_rank: + source_worker = w + break +``` + +This ensures each target TP rank connects to the corresponding seed TP rank, which is critical for tensor parallelism. + +### 3. Target Side: Loading Weights + +#### βœ… **Good: Byte Size Matching** + +**File**: `python/sglang/srt/model_loader/loader.py` (line ~2370) + +The code correctly uses byte sizes for matching: +```python +seed_ptr, seed_size = weight_info +local_size = tensor.numel() * tensor.element_size() +if seed_size != local_size: + raise RuntimeError(...) +``` + +**Comment**: This is correct! RDMA is a memcpy operation, so byte size matching is sufficient. Dtype differences (e.g., FP8 vs BF16) are handled by the model's quantization logic, not the transfer layer. + +#### ⚠️ **Concern: Missing Tensor Name Validation** + +**Issue**: The code assumes tensor names match exactly between seed and target. What if: +- Model architectures differ slightly? +- Tensor names have different prefixes? +- Some tensors are missing? + +**Recommendation**: Add more robust matching: + +```python +for name, tensor in model.named_parameters(): + weight_info = seed_weight_info.get(name, None) + if weight_info is None: + # Try fuzzy matching or provide helpful error + logger.warning( + f"ModelExpress: tensor '{name}' not found in seed metadata. " + f"Available tensors: {list(seed_weight_info.keys())[:10]}..." + ) + raise RuntimeError( + f"ModelExpress: cannot find weight info for {name} " + f"in seed metadata. This may indicate a model architecture mismatch." + ) +``` + +#### βœ… **Good: Ready State Coordination** + +**File**: `python/sglang/srt/model_loader/loader.py` (line ~2280) + +The code correctly waits for seed ready state: +```python +ready, session_id, metadata_hash = mx_client.wait_for_ready( + model_name, worker_id=tp_rank, +) +``` + +This ensures the target doesn't start transferring before the seed is fully initialized and stable. + +### 4. Configuration & CLI Arguments + +#### βœ… **Good: Clear CLI Arguments** + +**File**: `python/sglang/srt/server_args.py` + +The PR adds three new CLI arguments: +- `--model-express-url`: ModelExpress server URL +- `--model-express-model-name`: Model name for coordination +- `--model-express-source`: Flag to run as seed source + +**Comment**: The arguments are well-named and follow SGLang's existing patterns. + +#### ⚠️ **Concern: Validation Logic** + +**File**: `python/sglang/srt/server_args.py` (line ~2722) + +```python +if self.remote_instance_weight_loader_backend == "model_express": + if self.model_express_url is None: + logger.warning("Fallback load_format to 'auto'...") + self.load_format = "auto" +``` + +**Issue**: The validation silently falls back to `auto` instead of raising an error. This could lead to confusion. + +**Recommendation**: Make validation stricter or provide clearer messaging: + +```python +if self.remote_instance_weight_loader_backend == "model_express": + if self.model_express_url is None: + raise ValueError( + "--model-express-url is required when using " + "--remote-instance-weight-loader-backend=model_express" + ) + if not self.validate_transfer_engine(): + raise ValueError( + "TransferEngine is required for model_express backend. " + "Please install mooncake.engine or use a different backend." + ) +``` + +#### ⚠️ **Concern: Model Name Default** + +**File**: `python/sglang/srt/model_executor/model_runner.py` (line ~685) + +```python +model_name = ( + self.server_args.model_express_model_name + or self.server_args.model_path +) +``` + +**Issue**: Using `model_path` as default could lead to inconsistent model names (e.g., `/path/to/model` vs `meta-llama/Llama-3.1-70B`). + +**Recommendation**: Use a more consistent default or require explicit model name: + +```python +model_name = self.server_args.model_express_model_name +if not model_name: + # Extract model name from model_path (e.g., last component) + model_name = os.path.basename(self.server_args.model_path.rstrip('/')) + logger.warning( + f"ModelExpress: using model_name='{model_name}' from model_path. " + f"Consider setting --model-express-model-name explicitly." + ) +``` + +### 5. Error Handling + +#### βœ… **Good: Comprehensive Error Messages** + +The code provides clear error messages for common failure modes: +- Missing metadata +- Worker rank mismatch +- Size mismatches +- TransferEngine failures + +#### ⚠️ **Concern: Timeout Handling** + +**File**: `python/sglang/srt/model_loader/loader.py` (line ~2280) + +```python +ready, session_id, metadata_hash = mx_client.wait_for_ready( + model_name, worker_id=tp_rank, +) +if not ready: + raise RuntimeError("ModelExpress: timed out waiting for seed ready...") +``` + +**Issue**: The timeout is not configurable and may not be visible in the error message. + +**Recommendation**: Add timeout parameter and include it in error: + +```python +timeout_seconds = load_config.model_express_ready_timeout or 7200 # 2 hours default +ready, session_id, metadata_hash = mx_client.wait_for_ready( + model_name, worker_id=tp_rank, timeout_seconds=timeout_seconds, +) +if not ready: + raise RuntimeError( + f"ModelExpress: timed out waiting for seed ready " + f"(model={model_name}, worker={tp_rank}, timeout={timeout_seconds}s)" + ) +``` + +### 6. Integration with TransferEngine + +#### βœ… **Good: Reuses Existing TransferEngine Infrastructure** + +The PR correctly reuses: +- `register_memory_region()` for memory registration +- `batch_transfer_sync_read()` for RDMA transfers +- Existing TransferEngine initialization logic + +**Comment**: This is a clean integration that doesn't duplicate code. + +#### ⚠️ **Concern: TransferEngine Initialization Timing** + +**File**: `python/sglang/srt/model_executor/model_runner.py` (line ~1075) + +For seed sources, TransferEngine weight info is registered in `model_specific_adjustment()`: + +```python +if self.server_args.model_express_source: + if self.remote_instance_transfer_engine_weight_info is None: + self.remote_instance_transfer_engine_weight_info = ( + register_memory_region(self.model, self.remote_instance_transfer_engine) + ) + self._publish_model_express_metadata() +``` + +**Issue**: This happens after model loading. If the model is loaded via `DefaultModelLoader` (load_format=auto), the weights may have been processed/quantized, which could affect memory addresses. + +**Recommendation**: Document this timing and ensure weights are stable before registration: + +```python +# Ensure model weights are finalized before registering +# (post_load_weights may modify weights) +if hasattr(self.model, "post_load_weights"): + self.model.post_load_weights() + +# Now register memory regions (weights are stable) +if self.server_args.model_express_source: + ... +``` + +### 7. Testing & Edge Cases + +#### ❓ **Missing: Test Coverage** + +**Questions**: +1. Are there unit tests for `load_model_from_model_express()`? +2. Are there integration tests for the full flow (seed β†’ ModelExpress β†’ target)? +3. How is TP rank mismatch handled? +4. What happens if seed and target have different TP sizes? + +**Recommendation**: Add tests for: +- TP rank matching logic +- Byte size validation +- Missing tensor handling +- ModelExpress server unavailability +- Timeout scenarios + +### 8. Documentation + +#### ⚠️ **Missing: Usage Documentation** + +**Recommendation**: Add documentation explaining: +1. How to set up ModelExpress server +2. How to run seed instance with `--model-express-source` +3. How to run target instance with `--remote-instance-weight-loader-backend=model_express` +4. Model name coordination requirements +5. TP rank matching requirements + +**Example**: +```markdown +## ModelExpress Remote Instance Loading + +### Setup + +1. Start ModelExpress server: + ```bash + modelexpress-server --port 8001 + ``` + +2. Start seed instance: + ```bash + python -m sglang.launch_server \ + --model-path meta-llama/Llama-3.1-70B \ + --model-express-url localhost:8001 \ + --model-express-model-name meta-llama/Llama-3.1-70B \ + --model-express-source \ + --remote-instance-weight-loader-start-seed-via-transfer-engine + ``` + +3. Start target instance: + ```bash + python -m sglang.launch_server \ + --model-path meta-llama/Llama-3.1-70B \ + --load-format remote_instance \ + --remote-instance-weight-loader-backend model_express \ + --model-express-url localhost:8001 \ + --model-express-model-name meta-llama/Llama-3.1-70B + ``` + +### Requirements + +- Seed and target must have **matching TP sizes** (e.g., both TP=8) +- Each target TP rank connects to the corresponding seed TP rank +- ModelExpress server must be accessible from both instances +- TransferEngine must be initialized on both instances +``` + +## Specific PR Review Comments + +### High Priority + +1. **Dtype Inference**: Fix dtype mapping to use actual tensor dtypes instead of element size (see Section 2) +2. **NIXL Backend Support**: Add error handling for NIXL backend case (see Section 1) +3. **Validation**: Make CLI argument validation stricter (see Section 4) +4. **Model Name Default**: Improve model name default logic (see Section 4) + +### Medium Priority + +5. **Tensor Name Matching**: Add more robust tensor name matching with better error messages (see Section 3) +6. **Timeout Configuration**: Make timeout configurable and visible in errors (see Section 5) +7. **Memory Registration Timing**: Document/ensure weights are stable before registration (see Section 6) +8. **Documentation**: Add usage documentation (see Section 8) + +### Low Priority + +9. **Test Coverage**: Add comprehensive tests (see Section 7) +10. **Logging**: Add more detailed logging for debugging +11. **Error Recovery**: Consider retry logic for transient ModelExpress errors + +## Alignment with ModelExpress PR 157 + +### βœ… **Correct Integration** + +The SGLang PR correctly uses the `oneof` pattern from ModelExpress PR 157: +- Extracts `transfer_engine_session_id` from `backend_metadata` oneof +- Uses `WhichOneof()` to check backend type +- Provides appropriate error handling + +### ⚠️ **Missing: Backend Selection** + +The SGLang PR assumes TransferEngine backend. It doesn't: +- Check if source uses NIXL backend +- Provide fallback to NIXL if TransferEngine unavailable +- Allow configuration of preferred backend + +**Recommendation**: Consider adding backend selection logic similar to what was discussed in ModelExpress PR 157 feedback. + +## Conclusion + +PR 19920 provides a solid integration of ModelExpress coordination for remote instance weight loading. The implementation correctly: + +1. βœ… Uses the `oneof` pattern from ModelExpress PR 157 +2. βœ… Implements TP rank matching +3. βœ… Handles byte-size matching for mixed-dtype models +4. βœ… Coordinates ready state via ModelExpress + +**Key Improvements Needed**: +1. Fix dtype inference to use actual tensor dtypes +2. Add NIXL backend error handling +3. Improve validation and error messages +4. Add comprehensive documentation and tests + +The PR is well-structured and follows SGLang's existing patterns. With the suggested improvements, it will provide a robust foundation for ModelExpress-coordinated weight loading. + +--- + +## PR Review Comments + +This section provides specific comments to make directly on PR 19920, organized by file and line numbers. These comments should be added as inline code review comments on the PR. + +### File: `python/sglang/srt/model_loader/loader.py` + +**Comment 1 - Line ~2340 (load_model_from_model_express, backend_field check)** +``` +⚠️ Backend Type Handling: Add support for NIXL backend error case + +Currently, the code only handles `transfer_engine_session_id` and raises a generic error for other backends. Consider adding explicit handling for NIXL: + +```python +backend_field = source_worker.WhichOneof("backend_metadata") +if backend_field == "transfer_engine_session_id": + seed_session_id = source_worker.transfer_engine_session_id +elif backend_field == "nixl_metadata": + raise RuntimeError( + f"ModelExpress: source worker {tp_rank} uses NIXL backend, " + f"but MODEL_EXPRESS backend requires TransferEngine. " + f"Please use a source with TransferEngine backend or use NIXL directly." + ) +else: + raise RuntimeError( + f"ModelExpress: unknown backend_metadata={backend_field} " + f"for worker {tp_rank}. Expected 'transfer_engine_session_id'." + ) +``` + +This provides clearer error messages when backend types don't match. +``` + +**Comment 2 - Line ~2350 (tensor descriptor conversion)** +``` +βœ… Good: Byte size matching approach + +The use of raw byte sizes (`td.size`) for matching is correct for RDMA transfers. RDMA is a memcpy operation, so byte-level matching is appropriate regardless of dtype differences (FP8 vs BF16, etc.). + +Consider adding a comment explaining this: +```python +# Convert tensor descriptors to {name: (addr, size_bytes)} format +# Use raw byte sizes -- RDMA is a memcpy, dtype matching is not required +# The model's quantization logic handles dtype conversions, not the transfer layer +seed_weight_info = {} +``` +``` + +**Comment 3 - Line ~2370 (tensor name matching)** +``` +⚠️ Error Message Enhancement: Improve missing tensor error + +When a tensor name is not found, provide more context: + +```python +for name, tensor in model.named_parameters(): + weight_info = seed_weight_info.get(name, None) + if weight_info is None: + # Provide helpful context + available_names = list(seed_weight_info.keys()) + logger.error( + f"ModelExpress: tensor '{name}' not found in seed metadata. " + f"Available tensors ({len(available_names)}): {available_names[:5]}..." + ) + raise RuntimeError( + f"ModelExpress: cannot find weight info for '{name}' " + f"in seed metadata. This may indicate a model architecture mismatch " + f"or different model versions between seed and target." + ) +``` + +This helps debug model architecture mismatches. +``` + +**Comment 4 - Line ~2280 (wait_for_ready call)** +``` +⚠️ Timeout Configuration: Make timeout configurable + +The `wait_for_ready` timeout is not visible in the code. Consider: + +```python +timeout_seconds = getattr(load_config, 'model_express_ready_timeout', 7200) # 2 hours default +ready, session_id, metadata_hash = mx_client.wait_for_ready( + model_name, worker_id=tp_rank, timeout_seconds=timeout_seconds, +) +if not ready: + raise RuntimeError( + f"ModelExpress: timed out waiting for seed ready " + f"(model={model_name}, worker={tp_rank}, timeout={timeout_seconds}s). " + f"Check that seed instance is running and has published ready flag." + ) +``` + +Also consider adding `model_express_ready_timeout` to LoadConfig and ServerArgs. +``` + +### File: `python/sglang/srt/model_executor/model_runner.py` + +**Comment 5 - Line ~700 (_publish_model_express_metadata, dtype inference)** +``` +πŸ”§ Critical: Fix dtype inference from element size + +The current mapping is lossy and can misidentify dtypes: + +```python +element_size_to_dtype = {1: "float8_e4m3fn", 2: "bfloat16", 4: "float32", 8: "float64"} +``` + +**Problem**: Multiple dtypes share the same element size: +- Size 2: `float16`, `bfloat16`, `int16`, `uint16` +- Size 1: `int8`, `uint8`, `float8_e4m3fn`, `float8_e5m2` + +**Solution**: Use actual tensor dtype: + +```python +tensors = [] +for name, (addr, numel, element_size) in weight_info.items(): + # Get actual tensor to determine dtype + param_dict = dict(self.model.named_parameters()) + if name not in param_dict: + logger.warning(f"Parameter {name} not found in model, using element_size inference") + dtype_str = element_size_to_dtype.get(element_size, "unknown") + else: + tensor = param_dict[name] + dtype_str = str(tensor.dtype).replace("torch.", "") + + tensors.append(p2p_pb2.TensorDescriptor( + name=name, + addr=addr, + size=numel * element_size, + device_id=self.gpu_id, + dtype=dtype_str, + )) +``` + +**Alternative**: Modify `register_memory_region` to return dtype as well: +```python +# In remote_instance_weight_loader_utils.py +weight_info[name] = (addr, numel, element_size, str(tensor.dtype).replace("torch.", "")) +``` +``` + +**Comment 6 - Line ~685 (model_name default)** +``` +⚠️ Model Name Default: Improve consistency + +Using `model_path` as default can lead to inconsistent model names: + +```python +model_name = ( + self.server_args.model_express_model_name + or self.server_args.model_path +) +``` + +**Issue**: `model_path` might be `/path/to/model` while target uses `meta-llama/Llama-3.1-70B`. + +**Recommendation**: +```python +model_name = self.server_args.model_express_model_name +if not model_name: + # Extract model name from model_path (last component) + import os + model_name = os.path.basename(self.server_args.model_path.rstrip('/')) + logger.warning( + f"ModelExpress: using model_name='{model_name}' from model_path. " + f"Consider setting --model-express-model-name explicitly for consistency." + ) +``` + +Or require explicit model name: +```python +if not self.server_args.model_express_model_name: + raise ValueError( + "--model-express-model-name is required when using --model-express-source" + ) +``` +``` + +**Comment 7 - Line ~1075 (model_specific_adjustment, memory registration timing)** +``` +⚠️ Memory Registration Timing: Ensure weights are stable + +The memory registration happens after model loading, but weights may be modified by `post_load_weights()`. Consider: + +```python +# In model_specific_adjustment(), before ModelExpress publish: +# Ensure model weights are finalized (post_load_weights may modify weights) +if hasattr(self.model, "post_load_weights"): + self.model.post_load_weights() + +# Now register memory regions (weights are stable) +if self.server_args.model_express_source: + if ( + self.remote_instance_transfer_engine_weight_info is None + and self.remote_instance_transfer_engine is not None + ): + self.remote_instance_transfer_engine_weight_info = ( + register_memory_region(self.model, self.remote_instance_transfer_engine) + ) + self._publish_model_express_metadata() +``` + +This ensures memory addresses remain valid after registration. +``` + +**Comment 8 - Line ~720 (publish_ready call)** +``` +πŸ“ Metadata Hash: Consider computing actual hash + +Currently, `metadata_hash` is set to empty string: + +```python +mx_client.publish_ready( + model_name, + worker_id=self.tp_rank, + session_id=mx_client.session_id, + metadata_hash="", # Empty hash +) +``` + +Consider computing an actual hash of the tensor descriptors for validation: + +```python +import hashlib +metadata_str = ",".join(sorted(f"{td.name}:{td.addr}:{td.size}" for td in tensors)) +metadata_hash = hashlib.md5(metadata_str.encode()).hexdigest() + +mx_client.publish_ready( + model_name, + worker_id=self.tp_rank, + session_id=mx_client.session_id, + metadata_hash=metadata_hash, +) +``` + +This enables target-side validation that metadata hasn't changed. +``` + +### File: `python/sglang/srt/server_args.py` + +**Comment 9 - Line ~2722 (validation logic)** +``` +⚠️ Validation: Make validation stricter + +The current validation silently falls back to `auto`: + +```python +if self.remote_instance_weight_loader_backend == "model_express": + if self.model_express_url is None: + logger.warning("Fallback load_format to 'auto'...") + self.load_format = "auto" +``` + +**Recommendation**: Raise an error instead: + +```python +if self.remote_instance_weight_loader_backend == "model_express": + if self.model_express_url is None: + raise ValueError( + "--model-express-url is required when using " + "--remote-instance-weight-loader-backend=model_express" + ) + if not self.validate_transfer_engine(): + raise ValueError( + "TransferEngine is required for model_express backend. " + "Please install mooncake.engine or use a different backend." + ) +``` + +Silent fallback can lead to confusion when users expect model_express backend. +``` + +**Comment 10 - Line ~5235 (CLI argument help text)** +``` +πŸ“ Documentation: Enhance help text + +The help text for `--model-express-source` could be more descriptive: + +```python +parser.add_argument( + "--model-express-source", + action="store_true", + help=( + "Run as a ModelExpress seed source: publish TransferEngine metadata " + "to the ModelExpress server after loading weights. " + "Requires --model-express-url and TransferEngine initialization. " + "Target instances can then load weights via --remote-instance-weight-loader-backend=model_express." + ), +) +``` + +This clarifies the relationship between source and target modes. +``` + +**Comment 11 - Line ~5783 (validate_transfer_engine, ModelExpress source check)** +``` +βœ… Good: TransferEngine validation includes ModelExpress source + +The validation correctly checks for ModelExpress source mode: + +```python +if self.model_express_source: + return True +``` + +This ensures TransferEngine is initialized when running as a seed source. +``` + +### File: `python/sglang/srt/configs/load_config.py` + +**Comment 12 - Line ~78-79 (LoadConfig fields)** +``` +βœ… Good: Clean addition of ModelExpress fields + +The addition of `model_express_url` and `model_express_model_name` to LoadConfig is clean and follows existing patterns. + +Consider adding a comment: +```python +# ModelExpress coordination fields (for remote_instance_weight_loader_backend=model_express) +model_express_url: Optional[str] = None +model_express_model_name: Optional[str] = None +``` +``` + +### Testing & Documentation + +**Comment 13 - Missing: Test Coverage** +``` +βœ… Test Coverage Needed + +Please add tests for: +1. **TP rank matching**: Verify each target rank connects to correct seed rank +2. **Byte size validation**: Test size mismatch detection +3. **Missing tensor handling**: Test behavior when tensor names don't match +4. **ModelExpress server unavailability**: Test error handling +5. **Timeout scenarios**: Test ready state timeout handling +6. **Mixed dtype models**: Test FP8 + BF16 models + +Example test structure: +```python +def test_model_express_tp_rank_matching(): + # Test that target TP rank 0 connects to seed TP rank 0 + ... + +def test_model_express_byte_size_validation(): + # Test that size mismatches are detected + ... +``` +``` + +**Comment 14 - Missing: Usage Documentation** +``` +πŸ“š Documentation Needed + +Please add documentation explaining: +1. How to set up ModelExpress server +2. How to run seed instance with `--model-express-source` +3. How to run target instance with `--remote-instance-weight-loader-backend=model_express` +4. Model name coordination requirements +5. TP rank matching requirements (seed and target must have same TP size) + +Consider adding to `docs/advanced_features/rfork.md` or creating a new section. +``` + +### Summary of Priority Comments + +**High Priority (Must Address)**: +- Comment 5: Fix dtype inference from element size (critical for correctness) +- Comment 9: Make validation stricter (prevents silent failures) +- Comment 6: Improve model name default logic (prevents coordination failures) + +**Medium Priority (Should Address)**: +- Comment 1: Add NIXL backend error handling +- Comment 3: Improve missing tensor error messages +- Comment 4: Make timeout configurable +- Comment 7: Ensure weights are stable before registration +- Comment 13: Add test coverage + +**Low Priority (Nice to Have)**: +- Comment 2: Add comment explaining byte size matching +- Comment 8: Compute actual metadata hash +- Comment 10: Enhance help text +- Comment 12: Add comments to LoadConfig +- Comment 14: Add usage documentation diff --git a/modelexpress_client/python/modelexpress/training_publisher.py b/modelexpress_client/python/modelexpress/training_publisher.py index 11160955..ad69e051 100644 --- a/modelexpress_client/python/modelexpress/training_publisher.py +++ b/modelexpress_client/python/modelexpress/training_publisher.py @@ -68,6 +68,7 @@ def __init__( self._mx_source_id: str | None = None self._model_name: str = "" self._initialized = False + self._registered = False @property def mx_source_id(self) -> str | None: @@ -155,6 +156,10 @@ def publish_weights( This is the all-at-once variant. For layer-by-layer streaming, use :meth:`publish_layer` instead. + NIXL memory regions are registered only on the first call since + parameter tensor addresses stay constant across optimizer steps. + Subsequent calls reuse the cached metadata and descriptors. + Args: named_tensors: Mapping of parameter name to GPU tensor. step: Current training step (used for version tracking). @@ -166,7 +171,13 @@ def publish_weights( if not self._initialized: raise RuntimeError("Call initialize() before publish_weights()") - self._nixl.register_tensors(named_tensors) + if not self._registered: + self._nixl.register_tensors(named_tensors) + self._registered = True + logger.info( + f"Registered {len(named_tensors)} tensors with NIXL " + f"(metadata={len(self._nixl.nixl_metadata)} bytes)" + ) metadata = self._nixl.nixl_metadata descriptors = self._nixl.tensor_descriptors From f7bcb16e31a706953f7a1c659cdd2e20ba01feed Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Fri, 10 Apr 2026 16:16:23 -0700 Subject: [PATCH 06/25] feat: add receive_weights_scratch() for cross-format RDMA transfers Allocates temporary GPU buffers matching the source's tensor layout, receives via NIXL RDMA, and yields (name, tensor) pairs in HF format. The caller's model.load_weights() handles name mapping and tensor fusion (e.g. HF q/k/v -> vLLM qkv_proj). Made-with: Cursor Signed-off-by: Kavin Krishnan --- .../python/modelexpress/refit_receiver.py | 91 +++++++++++++++++++ 1 file changed, 91 insertions(+) diff --git a/modelexpress_client/python/modelexpress/refit_receiver.py b/modelexpress_client/python/modelexpress/refit_receiver.py index 1bcbab8a..a204265f 100644 --- a/modelexpress_client/python/modelexpress/refit_receiver.py +++ b/modelexpress_client/python/modelexpress/refit_receiver.py @@ -231,6 +231,97 @@ def receive_weights( if td.name in self._nixl._tensors: yield td.name, self._nixl._tensors[td.name] + def receive_weights_scratch( + self, + source: SourceRef, + timeout_seconds: float = 300.0, + ) -> Iterator[tuple[str, torch.Tensor]]: + """Receive weights into scratch GPU buffers via NIXL RDMA. + + Unlike :meth:`receive_weights` which requires pre-registered model + buffers with matching tensor names, this method allocates temporary + GPU tensors that match the source's layout, transfers via RDMA, and + yields the results. The caller feeds these through + ``model.load_weights()`` which handles name mapping and tensor fusion. + + This is the correct approach when the source (trainer) publishes + HuggingFace-format weights but the target (vLLM) uses fused internal + parameter names. + + Args: + source: A :class:`SourceRef` obtained from :meth:`poll_for_source`. + timeout_seconds: Maximum time to wait for the RDMA transfer. + + Yields: + ``(name, tensor)`` pairs in HF checkpoint format. + """ + if not self._initialized: + raise RuntimeError("Call initialize() before receive_weights_scratch()") + + meta_resp = self._client.get_metadata( + mx_source_id=source.mx_source_id, + worker_id=source.worker_id, + ) + if not meta_resp.found: + raise RuntimeError( + f"Source {source.mx_source_id}/{source.worker_id} not found on MX Server" + ) + + worker = meta_resp.worker + source_tensors = [ + TensorDescriptor( + name=t.name, + addr=t.addr, + size=t.size, + device_id=t.device_id, + dtype=t.dtype, + ) + for t in worker.tensors + ] + + _DTYPE_MAP = { + "torch.bfloat16": torch.bfloat16, + "torch.float16": torch.float16, + "torch.float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, + } + + scratch_tensors: dict[str, torch.Tensor] = {} + for td in source_tensors: + dt = _DTYPE_MAP.get(td.dtype, torch.bfloat16) + elem_size = torch.tensor([], dtype=dt).element_size() + numel = td.size // elem_size + scratch_tensors[td.name] = torch.empty( + numel, dtype=dt, device=f"cuda:{self._device_id}" + ) + + logger.info( + f"Allocated {len(scratch_tensors)} scratch buffers " + f"({sum(t.numel() * t.element_size() for t in scratch_tensors.values()) / 1e9:.2f} GB)" + ) + + self._nixl.register_tensors(scratch_tensors) + + transferred, skipped, elapsed = self._nixl.receive_from_source( + source_metadata=worker.nixl_metadata, + source_tensors=source_tensors, + timeout_seconds=timeout_seconds, + ) + + bandwidth_gbps = (transferred * 8) / (elapsed * 1e9) if elapsed > 0 else 0.0 + logger.info( + f"RDMA transfer complete: {transferred / 1e9:.2f} GB, " + f"{len(source_tensors)} tensors, {elapsed:.2f}s, " + f"{bandwidth_gbps:.1f} Gbps (step={source.training_step})" + ) + + self._current_step = source.training_step + + for name, tensor in scratch_tensors.items(): + yield name, tensor + def receive_weights_from_metadata( self, nixl_metadata: bytes, From f978b6a915668c127bf20ab7fe120ae2b5b2fc90 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Fri, 10 Apr 2026 23:34:18 -0700 Subject: [PATCH 07/25] fix: disable transfer coalescing in receive_weights_scratch (incompatible with scratch buffers) Made-with: Cursor Signed-off-by: Kavin Krishnan --- modelexpress_client/python/modelexpress/refit_receiver.py | 1 + 1 file changed, 1 insertion(+) diff --git a/modelexpress_client/python/modelexpress/refit_receiver.py b/modelexpress_client/python/modelexpress/refit_receiver.py index a204265f..8425b7f5 100644 --- a/modelexpress_client/python/modelexpress/refit_receiver.py +++ b/modelexpress_client/python/modelexpress/refit_receiver.py @@ -308,6 +308,7 @@ def receive_weights_scratch( source_metadata=worker.nixl_metadata, source_tensors=source_tensors, timeout_seconds=timeout_seconds, + coalesce_transfers=False, ) bandwidth_gbps = (transferred * 8) / (elapsed * 1e9) if elapsed > 0 else 0.0 From 6b00b0f0941fdcf5e60f7a5b4139b05b69f8bb0c Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Sat, 11 Apr 2026 10:54:14 -0700 Subject: [PATCH 08/25] fix: accept tensor_shapes in receive_weights_scratch for correct weight reshaping Made-with: Cursor Signed-off-by: Kavin Krishnan --- modelexpress_client/python/modelexpress/refit_receiver.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/modelexpress_client/python/modelexpress/refit_receiver.py b/modelexpress_client/python/modelexpress/refit_receiver.py index 8425b7f5..26fd5df0 100644 --- a/modelexpress_client/python/modelexpress/refit_receiver.py +++ b/modelexpress_client/python/modelexpress/refit_receiver.py @@ -235,6 +235,7 @@ def receive_weights_scratch( self, source: SourceRef, timeout_seconds: float = 300.0, + tensor_shapes: dict[str, tuple[int, ...]] | None = None, ) -> Iterator[tuple[str, torch.Tensor]]: """Receive weights into scratch GPU buffers via NIXL RDMA. @@ -289,6 +290,7 @@ def receive_weights_scratch( } scratch_tensors: dict[str, torch.Tensor] = {} + scratch_shapes: dict[str, tuple[int, ...]] = {} for td in source_tensors: dt = _DTYPE_MAP.get(td.dtype, torch.bfloat16) elem_size = torch.tensor([], dtype=dt).element_size() @@ -296,6 +298,7 @@ def receive_weights_scratch( scratch_tensors[td.name] = torch.empty( numel, dtype=dt, device=f"cuda:{self._device_id}" ) + scratch_shapes[td.name] = (numel,) logger.info( f"Allocated {len(scratch_tensors)} scratch buffers " @@ -321,6 +324,8 @@ def receive_weights_scratch( self._current_step = source.training_step for name, tensor in scratch_tensors.items(): + if tensor_shapes and name in tensor_shapes: + tensor = tensor.view(tensor_shapes[name]) yield name, tensor def receive_weights_from_metadata( From 953b234f5f964ff2ecb224c0cbd11716798bda66 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Mon, 13 Apr 2026 15:54:25 -0700 Subject: [PATCH 09/25] chore: remove review/feedback docs from kavink/RL branch Made-with: Cursor Signed-off-by: Kavin Krishnan --- docs/165_review.md | 175 ------ docs/170_feedback.md | 148 ----- docs/feedback.md | 1160 -------------------------------------- docs/feedback_pr19920.md | 863 ---------------------------- 4 files changed, 2346 deletions(-) delete mode 100644 docs/165_review.md delete mode 100644 docs/170_feedback.md delete mode 100644 docs/feedback.md delete mode 100644 docs/feedback_pr19920.md diff --git a/docs/165_review.md b/docs/165_review.md deleted file mode 100644 index bca75e68..00000000 --- a/docs/165_review.md +++ /dev/null @@ -1,175 +0,0 @@ -# PR 165 Review: Metadata Resiliency Phase 1 - -Reviewer: KavinKrishnan -PR: https://github.com/ai-dynamo/modelexpress/pull/165 -Author: zhengluo-nv - -## Overall Assessment - -Good simplification. Merging ready state into WorkerRecord and eliminating the -memory/layered backends reduces code paths and configuration permutations -significantly. The UpdateStatus RPC is cleaner than the old -PublishReady/GetReady pair. Tests are solid. - -Main concerns: (1) the stability_verified removal breaks our TRT-LLM -DeepGEMM warmup workflow, (2) the retry-on-RDMA-failure path in -vllm_loader.py does not check status before re-using stale workers, and -(3) a few edge cases in the K8s backend can cause silent data loss. - -## Comments to Leave on PR - -### 1. BLOCKING - stability_verified removal breaks DeepGEMM warmup gating - -File: modelexpress_common/proto/p2p.proto, lines 62-67 (new WorkerMetadata fields) -Also: modelexpress_server/src/k8s_types.rs, lines 66-80 (new WorkerStatus struct) - -The old stability_verified field was used to gate P2P transfers until after -DeepGEMM warmup completes on the source. For DeepSeek V3 / Kimi K2.5, this -warmup takes 30-60 seconds and writes to GPU memory. Transferring weights -before it finishes produces corrupted inference. - -The new SourceStatus enum only has Initializing, Ready, Stale. There is -no state between "metadata published" and "fully warmed up and safe to transfer." - -Suggestion: Add a SOURCE_STATUS_PENDING_VERIFICATION = 4 state (as Zheng -proposed in the PR comments), or split Ready into METADATA_READY and -SERVING_READY. The source should transition: -Initializing -> PendingVerification -> Ready. Targets should only transfer -from workers in Ready status. This makes stability_verified expressible -via the status enum without needing a separate boolean. - -### 2. IMPORTANT - Target retry loop does not filter by worker status - -File: modelexpress_client/python/modelexpress/vllm_loader.py, lines 476-490 -(retry metadata refresh inside the transfer attempt loop) - -When an RDMA transfer fails and the target re-fetches metadata, it matches -workers only by worker_rank and len(w.tensors) > 0: - - response = self._mx_client.get_metadata(model_name) - for w in response.workers: - if w.worker_rank == device_id and len(w.tensors) > 0: - source_worker = w - -This does not check w.status == SOURCE_STATUS_READY. If the source restarted -and is in Initializing or Stale state, the target will attempt RDMA against -potentially invalid GPU addresses. - -The initial detection at _detect_source_worker (line ~353) correctly does: - - ready = p2p_pb2.SOURCE_STATUS_READY - for w in metadata_resp.workers: - if w.worker_rank == device_id and w.status == ready and len(w.tensors) > 0: - -So this is just the retry path missing the identical check. - -### 3. IMPORTANT - update_status call not wrapped in error handling - -File: modelexpress_client/python/modelexpress/vllm_loader.py, lines 212-219 - -After successfully publishing metadata, the source calls update_status but -does not check the return value: - - if success: - logger.info(f"[Worker {device_id}] Published metadata to MX server") - mx_client.update_status( - model_name=model_name, - worker_id=device_id, - status=p2p_pb2.SOURCE_STATUS_READY, - ) - -If this gRPC call fails (network blip, server restart), update_status -returns False but execution continues. The source thinks it published -READY, but targets polling GetMetadata will never see Ready status for -this worker -- they will see Initializing (or whatever status was set -during publish_metadata) and skip it. - -Suggestion: Check the return value and raise on failure: - - if not mx_client.update_status(...): - raise RuntimeError( - f"[Worker {device_id}] Failed to update status to READY" - ) - -### 4. NIT - K8s update_status silently returns Ok when worker not found - -File: modelexpress_server/src/metadata_backend/kubernetes.rs (update_status fn) - -When a worker ID does not exist in the CR's worker list, the K8s backend -logs at debug level and returns Ok(()): - - } else { - debug!( - "update_status: worker {} not found in CR '{}', skipping", - worker_id, cr_name - ); - return Ok(()); - } - -The Redis backend returns Err for the same case (Lua script returns 0, -check_patched converts to error). This inconsistency means callers cannot -distinguish "status updated" from "worker not found" on the K8s backend. - -Suggestion: Return Err to match Redis, or if the intent is to be lenient -(worker calls update_status before publish_metadata arrives), document -that and make the Redis backend match by returning Ok when patched == 0. - -### 5. NIT - status_proto_from_name rejects Unknown -- breaks CRD backward compat - -File: modelexpress_server/src/k8s_types.rs, lines 83-92 - -status_proto_from_name returns None for "Unknown", and the K8s backend -get_metadata converts None into a hard error. But the CRD schema defaults -status to "Unknown", so pre-existing CRs will fail to read. - -Suggestion: Map "Unknown" to Some(0) since proto defines SOURCE_STATUS_UNKNOWN = 0. - -### 6. MINOR - CRD lost all useful printer columns except Model and Age - -File: examples/p2p_transfer_k8s/deploy/persistence/crd-modelmetadata.yaml, lines 110-115 - -kubectl get modelmetadata now only shows Model and Age. Add back Workers count -and a Status summary column. - -### 7. MINOR - metadata.md just has WIP banner but keeps 600 lines of stale content - -File: docs/metadata.md, lines 1-3 - -Either update to match new architecture or delete and point to ARCHITECTURE.md. -Stale doc with one-line disclaimer is worse than no doc. - -### 8. MINOR - Dead condition types remain in CRD schema - -File: examples/p2p_transfer_k8s/deploy/persistence/crd-modelmetadata.yaml, lines 81-82 - -AllWorkersPublished and Ready conditions are defined in schema but nothing in -code populates them anymore. Remove or re-implement. - -### 9. NIT - main.rs errors do not identify which backend failed - -File: modelexpress_server/src/main.rs, lines 104-113 - -Error messages say "P2P metadata backend" without naming which backend or -connection target. Include MX_METADATA_BACKEND value in the message. - -### 10. QUESTION - Local dev story without in-memory backend - -File: layered.rs (deleted), memory.rs (deleted) - -MX_METADATA_BACKEND is now required. Local dev needs Redis or K8s. -Document the recommended local setup (Docker Compose with Redis sidecar?). - -## Summary Table - -| # | Severity | File | Lines | Topic | -|---|----------|------|-------|-------| -| 1 | BLOCKING | p2p.proto, k8s_types.rs | 62-67, 66-80 | stability_verified removal | -| 2 | IMPORTANT | vllm_loader.py | 476-490 | Retry loop missing status check | -| 3 | IMPORTANT | vllm_loader.py | 212-219 | update_status failure ignored | -| 4 | NIT | kubernetes.rs | 500-510 | Inconsistent Ok vs Err | -| 5 | NIT | k8s_types.rs | 83-92 | Unknown breaks backward compat | -| 6 | MINOR | crd-modelmetadata.yaml | 110-115 | Printer columns removed | -| 7 | MINOR | metadata.md | 1-3 | Stale doc | -| 8 | MINOR | crd-modelmetadata.yaml | 81-82 | Dead conditions | -| 9 | NIT | main.rs | 104-113 | Non-descriptive errors | -| 10 | QUESTION | layered.rs, memory.rs | deleted | Local dev story | diff --git a/docs/170_feedback.md b/docs/170_feedback.md deleted file mode 100644 index 90f2c121..00000000 --- a/docs/170_feedback.md +++ /dev/null @@ -1,148 +0,0 @@ -# PR 170 Review: Multi-Source P2P Metadata with Per-Worker APIs - -Reviewer: KavinKrishnan -PR: https://github.com/ai-dynamo/modelexpress/pull/170 -Author: zhengluo-nv - -## Overall Assessment - -Strong architectural redesign. The move from model-name keys to content-addressed -SourceIdentity (mx_source_id), per-worker publish/get, and ListSources RPC -correctly supports multiple concurrent source replicas. The two-step -ListSourcesβ†’GetMetadata flow with worker_rank filtering eliminates fan-out -RPCs. SourceTransferError for selective STALE marking is the right approach. -K8s update_status now returns Err on missing worker (CodeRabbit fix applied). - -Main concerns: (1) update_status failure in _publish_metadata_and_ready is -silently ignored, (2) TensorDescriptor lacks shape field needed for TRT-LLM -tensor reconstruction (main has it), (3) no PENDING_VERIFICATION state for -DeepGEMM warmup gating, and (4) a few doc/CRD cleanups. - -## Comments to Leave on PR - -### 1. IMPORTANT - update_status failure silently ignored in _publish_metadata_and_ready - -File: modelexpress_client/python/modelexpress/vllm_loader.py, lines 265-275 - -After successfully publishing metadata, the source calls update_status with -SOURCE_STATUS_READY. If this gRPC call fails (network blip, server restart), -the code only logs and continues: - -```python -success = mx_client.update_status( - mx_source_id=mx_source_id, - worker_id=worker_id, - worker_rank=global_rank, - status=p2p_pb2.SOURCE_STATUS_READY, -) -if not success: - logger.error( - f"[Worker {global_rank}] UpdateStatus to READY failed for " - f"model '{identity.model_name}' (mx_source_id={mx_source_id})" - ) -``` - -The source thinks it is ready, but targets never see Ready status and will -never discover this worker. Same issue as PR 165 #3. - -Suggestion: Check the return value and raise on failure so the source retries -or fails loudly instead of advertising readiness that targets cannot use. - -### 2. IMPORTANT - TensorDescriptor missing shape field - -File: modelexpress_common/proto/p2p.proto, lines 91-104 (TensorDescriptor message) - -PR 170's TensorDescriptor has name, addr, size, device_id, dtype but no shape. -Main branch (and PR 169) added `repeated int64 shape = 6` for proper tensor -reconstruction on the target. TRT-LLM and some vLLM models need shape to -correctly rebuild tensors after RDMA receive. - -Suggestion: Add `repeated int64 shape = 6` to TensorDescriptor and regenerate -stubs. Ensure vllm_loader and trtllm_loader pass shape when building -TensorDescriptor protos. - -### 3. BLOCKING (for TRT-LLM) - No PENDING_VERIFICATION state for DeepGEMM warmup - -File: modelexpress_common/proto/p2p.proto, lines 112-117 (SourceStatus enum) -Also: modelexpress_server/src/k8s_types.rs, lines 89-98 (status_name_from_proto) - -The SourceStatus enum has Unknown, Initializing, Ready, Stale. There is no -state between "metadata published" and "fully warmed up and safe to transfer." -For TRT-LLM DeepGEMM warmup (DeepSeek V3, Kimi K2.5), warmup takes 30-60 seconds -and writes to GPU memory. Transferring before it finishes produces corrupted -inference. - -Commit c75a58e had PENDING_VERIFICATION but a6cbdf5 reverted it to Unknown. -Suggestion: Re-add SOURCE_STATUS_PENDING_VERIFICATION = 4 (or use value that -does not shift Ready/Stale). Source transitions: Initializing -> -PendingVerification -> Ready. Targets only transfer from Ready. - -### 4. NIT - validate_identity only checks model_name - -File: modelexpress_server/src/source_identity.rs, lines 25-30 - -validate_identity only checks identity.model_name. SourceIdentity includes -backend_framework and mx_source_type. backend_framework=0 (UNKNOWN) may -indicate uninitialized or malformed identity. - -Suggestion (optional): Add validation for backend_framework when -BACKEND_FRAMEWORK_UNKNOWN should never be published. Return Err with clear -message so malformed identities are rejected early. - -### 5. MINOR - CRD printer columns reduced to Model and Age - -File: examples/p2p_transfer_k8s/deploy/persistence/crd-modelmetadata.yaml, lines 119-126 - -kubectl get modelmetadata now only shows Model and Age. Add back Workers count -and optionally a Status summary column for easier debugging. - -### 6. MINOR - Dead condition types in CRD schema - -File: examples/p2p_transfer_k8s/deploy/persistence/crd-modelmetadata.yaml, lines 84-86 - -AllWorkersPublished and Ready conditions are defined in the schema enum but -nothing in code populates them. Remove or re-implement. - -### 7. NIT - Docstring coverage below threshold - -Pre-merge check reports docstring coverage 62.88% (required 80%). Add -docstrings for functions missing them to satisfy the threshold. - -### 8. QUESTION - Stale source detection latency (~35s per dead source) - -PR description notes: CRDs from dead pods remain "Ready" until a new target -tries them and gets NIXL_ERR_REMOTE_DISCONNECT; UCX connection timeout is -~35s per stale source. Is there a plan for heartbeat/TTL to mark stale workers -automatically? Document as known limitation or track as follow-up. - -### 9. NIT - _collect_cuda_tensors vs _iter_module_tensors - -File: modelexpress_client/python/modelexpress/vllm_loader.py - -PR 170 uses _collect_cuda_tensors (named_parameters only) instead of the -main-branch _iter_module_tensors which also finds buffers and tensor -attributes (e.g. FP8 scale_inv). For FP8 models, scale tensors may be -missed. Verify this is intentional or restore the more thorough traversal. - -### 10. MINOR - main.rs errors do not identify which backend failed - -File: modelexpress_server/src/main.rs (if present in PR 170) - -Error messages that say "P2P metadata backend" without naming which backend -or connection target make debugging harder. Include MX_METADATA_BACKEND value -(or equivalent) in the message. - -## Summary Table - -| # | Severity | File | Lines | Topic | -|---|----------|------|-------|-------| -| 1 | IMPORTANT | vllm_loader.py | 265-275 | update_status failure ignored | -| 2 | IMPORTANT | p2p.proto | 91-104 | TensorDescriptor missing shape | -| 3 | BLOCKING (TRT-LLM) | p2p.proto, k8s_types.rs | 112-117, 89-98 | No PendingVerification for warmup | -| 4 | NIT | source_identity.rs | 25-30 | validate_identity scope | -| 5 | MINOR | crd-modelmetadata.yaml | 119-126 | Printer columns | -| 6 | MINOR | crd-modelmetadata.yaml | 84-86 | Dead conditions | -| 7 | NIT | (various) | β€” | Docstring coverage | -| 8 | QUESTION | β€” | β€” | Stale detection / heartbeat | -| 9 | NIT | vllm_loader.py | _collect_cuda_tensors | FP8 scale tensors | -| 10 | MINOR | main.rs | β€” | Non-descriptive backend errors | diff --git a/docs/feedback.md b/docs/feedback.md deleted file mode 100644 index ac455abf..00000000 --- a/docs/feedback.md +++ /dev/null @@ -1,1160 +0,0 @@ -# PR 157: Add TransferEngine Backend to P2P Metadata - Design Review & Feedback - -## Executive Summary - -This document provides a design overview and feedback for PR 157, which adds TransferEngine backend support to ModelExpress's P2P metadata system. The review is informed by: -- Current ModelExpress P2P metadata architecture (NIXL-based) -- SGLang's R-Fork implementation using TransferEngine -- Best practices for multi-backend transfer systems - -## Current Architecture Overview - -### Existing P2P Metadata System - -ModelExpress currently supports P2P weight transfers using **NIXL** (NVIDIA Inter-Node eXchange Library) for RDMA-based GPU-to-GPU transfers: - -1. **Metadata Structure**: - - `WorkerMetadata` contains `nixl_metadata` (byte blob) + tensor descriptors - - Metadata is published via gRPC to ModelExpress server - - Server stores metadata in Redis/Kubernetes/In-memory backends - -2. **Transfer Flow**: - - Source: Loads model β†’ Registers tensors with NIXL β†’ Publishes metadata β†’ Signals ready - - Target: Queries metadata β†’ Adds remote NIXL agents β†’ Executes RDMA transfers - -3. **Backend Abstraction**: - - Server-side: `MetadataBackend` trait (Memory/Redis/Kubernetes) - - Client-side: `NixlTransferManager` for NIXL operations - -## Proposed Design: TransferEngine Backend Support - -### Design Goals (Inferred from SGLang R-Fork) - -Based on [SGLang's R-Fork documentation](https://raw.githubusercontent.com/sgl-project/sglang/main/docs/advanced_features/rfork.md), TransferEngine support should: - -1. **Enable zero-copy weight loading** from running instances -2. **Support multiple backends**: NCCL, TransferEngine (and potentially NIXL) -3. **Backend selection** based on availability and configuration -4. **Metadata routing** to appropriate backend based on backend type - -### Expected Changes - -PR 157 likely introduces: - -1. **Protocol Buffer Updates** (`p2p.proto`): - - Add `backend_type` field to `WorkerMetadata` (enum: NIXL, TRANSFER_ENGINE, NCCL) - - Add TransferEngine-specific metadata fields (connection info, ports, etc.) - - Maintain backward compatibility with existing NIXL-only deployments - -2. **Server-Side Changes**: - - Extend `WorkerRecord` to store backend type - - Update metadata serialization/deserialization - - Ensure backend-agnostic storage (metadata backend should not care about transfer backend) - -3. **Client-Side Changes**: - - Add `TransferEngineTransferManager` (parallel to `NixlTransferManager`) - - Backend selection logic (NIXL vs TransferEngine) - - TransferEngine-specific connection establishment - -## Design Feedback & Recommendations - -### 1. Protocol Buffer Design - -#### βœ… **Recommendation: Use OneOf for Backend-Specific Metadata** - -**Current Approach (Inferred)**: -```protobuf -message WorkerMetadata { - uint32 worker_rank = 1; - bytes nixl_metadata = 2; // Only NIXL - repeated TensorDescriptor tensors = 3; -} -``` - -**Recommended Approach**: -```protobuf -message WorkerMetadata { - uint32 worker_rank = 1; - - // Backend type determines which metadata field is populated - BackendType backend_type = 2; - - // Backend-specific metadata (one of these is populated) - oneof backend_metadata { - NixlBackendMetadata nixl_metadata = 3; - TransferEngineBackendMetadata transfer_engine_metadata = 4; - NcclBackendMetadata nccl_metadata = 5; // Future-proofing - } - - repeated TensorDescriptor tensors = 6; -} - -enum BackendType { - BACKEND_TYPE_UNSPECIFIED = 0; - BACKEND_TYPE_NIXL = 1; - BACKEND_TYPE_TRANSFER_ENGINE = 2; - BACKEND_TYPE_NCCL = 3; -} - -message NixlBackendMetadata { - bytes nixl_agent_metadata = 1; // Serialized NIXL agent blob -} - -message TransferEngineBackendMetadata { - // Connection information for TransferEngine - string seed_instance_ip = 1; - uint32 seed_instance_service_port = 2; - repeated uint32 send_weights_group_ports = 3; // For NCCL backend - // Additional TransferEngine-specific fields as needed -} -``` - -**Rationale**: -- **Type Safety**: Clear separation of backend-specific metadata -- **Extensibility**: Easy to add new backends (NCCL, custom) -- **Backward Compatibility**: Can deprecate old `nixl_metadata` field gradually -- **Validation**: Server can validate that backend_type matches populated metadata - -#### ⚠️ **Concern: Backward Compatibility** - -**Issue**: Existing deployments use `bytes nixl_metadata`. How does PR 157 handle migration? - -**Recommendations**: -1. **Deprecation Strategy**: Keep `nixl_metadata` field but mark as deprecated -2. **Migration Path**: Server should accept both old and new formats during transition -3. **Auto-Detection**: If `backend_type` is unset but `nixl_metadata` is present, infer `BACKEND_TYPE_NIXL` - -**Example Migration Code**: -```rust -impl From for WorkerRecord { - fn from(meta: WorkerMetadata) -> Self { - let (backend_type, metadata_bytes) = match meta.backend_type { - BackendType::Nixl | BackendType::Unspecified => { - // Handle legacy: if backend_type unset but nixl_metadata present - if !meta.nixl_metadata.is_empty() { - (BackendType::Nixl, meta.nixl_metadata) - } else if let Some(nixl) = meta.backend_metadata.nixl_metadata { - (BackendType::Nixl, nixl.nixl_agent_metadata) - } else { - // Error: no metadata - return Err(...); - } - } - BackendType::TransferEngine => { - if let Some(te) = meta.backend_metadata.transfer_engine_metadata { - // Serialize TransferEngine metadata - (BackendType::TransferEngine, serialize_te_metadata(te)?) - } else { - return Err(...); - } - } - }; - - Self { - worker_rank: meta.worker_rank, - backend_type, - backend_metadata: metadata_bytes, - tensors: ... - } - } -} -``` - -### 2. Server-Side Storage Design - -#### βœ… **Recommendation: Store Backend Type in WorkerRecord** - -**Current Structure**: -```rust -pub struct WorkerRecord { - pub worker_rank: u32, - pub nixl_metadata: Vec, // Backend-agnostic name needed - pub tensors: Vec, -} -``` - -**Recommended Structure**: -```rust -pub struct WorkerRecord { - pub worker_rank: u32, - pub backend_type: BackendType, // NEW: Track backend type - pub backend_metadata: Vec, // RENAMED: Generic name (was nixl_metadata) - pub tensors: Vec, -} -``` - -**Rationale**: -- **Clarity**: `backend_metadata` is more accurate than `nixl_metadata` -- **Type Safety**: Backend type is explicit in storage layer -- **Query Support**: Can filter/query by backend type if needed - -#### ⚠️ **Concern: Storage Backend Compatibility** - -**Issue**: Redis/Kubernetes backends serialize `WorkerRecord`. How does PR 157 handle: -1. Existing stored data (only NIXL)? -2. Mixed deployments (some workers NIXL, some TransferEngine)? - -**Recommendations**: -1. **Default Backend Type**: When deserializing old data without `backend_type`, default to `BackendType::Nixl` -2. **Versioned Schema**: Consider adding a `schema_version` field for future migrations -3. **Validation**: Reject metadata where backend_type doesn't match metadata format - -**Example**: -```rust -impl From for WorkerRecord { - fn from(json: WorkerRecordJson) -> Self { - Self { - worker_rank: json.worker_rank, - backend_type: json.backend_type.unwrap_or(BackendType::Nixl), // Default for old data - backend_metadata: json.backend_metadata, // Was nixl_metadata - tensors: ... - } - } -} -``` - -### 3. Client-Side Backend Selection - -#### βœ… **Recommendation: Factory Pattern for Transfer Managers** - -**Current Approach**: -```python -class NixlTransferManager: - def __init__(self, agent_name: str, device_id: int): - ... -``` - -**Recommended Approach**: -```python -class TransferManagerFactory: - @staticmethod - def create( - backend_type: BackendType, - agent_name: str, - device_id: int, - **kwargs - ) -> TransferManager: - if backend_type == BackendType.NIXL: - return NixlTransferManager(agent_name, device_id) - elif backend_type == BackendType.TRANSFER_ENGINE: - return TransferEngineTransferManager( - agent_name, device_id, - seed_instance_ip=kwargs.get("seed_instance_ip"), - seed_instance_port=kwargs.get("seed_instance_port"), - ... - ) - else: - raise ValueError(f"Unsupported backend: {backend_type}") - -# Usage -metadata = get_metadata_from_server(model_name) -for worker in metadata.workers: - manager = TransferManagerFactory.create( - backend_type=worker.backend_type, - agent_name=f"worker_{worker.worker_rank}", - device_id=worker.worker_rank, - **extract_transfer_engine_config(worker.backend_metadata) - ) -``` - -**Rationale**: -- **Clean Separation**: Each backend has its own manager -- **Easy Testing**: Can mock individual backends -- **Configuration**: Backend-specific config passed via kwargs - -#### ⚠️ **Concern: Backend Availability Detection** - -**Issue**: How does the client know which backends are available at runtime? - -**Recommendations**: -1. **Runtime Detection**: Check for NIXL/TransferEngine availability (similar to `is_nixl_available()`) -2. **Fallback Strategy**: If preferred backend unavailable, fall back to alternative -3. **Error Messages**: Clear errors when required backend is missing - -**Example**: -```python -def select_backend(preferred: BackendType) -> BackendType: - """Select available backend with fallback.""" - if preferred == BackendType.TRANSFER_ENGINE: - if is_transfer_engine_available(): - return BackendType.TRANSFER_ENGINE - elif is_nixl_available(): - logger.warning("TransferEngine not available, falling back to NIXL") - return BackendType.NIXL - else: - raise RuntimeError("No transfer backend available") - elif preferred == BackendType.NIXL: - if is_nixl_available(): - return BackendType.NIXL - else: - raise RuntimeError("NIXL not available") - ... -``` - -### 4. Alignment with SGLang R-Fork - -#### βœ… **Recommendation: Match SGLang's Configuration Pattern** - -SGLang uses command-line arguments for TransferEngine configuration: -```bash ---load-format remote_instance ---remote-instance-weight-loader-backend transfer_engine ---remote-instance-weight-loader-seed-instance-ip ---remote-instance-weight-loader-seed-instance-service-port -``` - -**ModelExpress Equivalent**: -```python -# Environment variables or config -MX_TRANSFER_BACKEND=transfer_engine -MX_TRANSFER_ENGINE_SEED_IP= -MX_TRANSFER_ENGINE_SEED_PORT= -``` - -**Recommendations**: -1. **Consistent Naming**: Use similar parameter names to SGLang for familiarity -2. **Documentation**: Reference SGLang's R-Fork docs in ModelExpress docs -3. **Validation**: Validate that seed instance is reachable before publishing metadata - -### 5. Metadata Exchange & Routing - -#### βœ… **Recommendation: Backend-Aware Metadata Routing** - -**Issue**: When target receives metadata, it must route to correct backend. - -**Current Flow**: -``` -Target β†’ GetMetadata(model_name) β†’ Server β†’ Returns WorkerMetadata -Target β†’ Extract nixl_metadata β†’ Add remote NIXL agent -``` - -**Recommended Flow**: -``` -Target β†’ GetMetadata(model_name) β†’ Server β†’ Returns WorkerMetadata (with backend_type) -Target β†’ Check backend_type β†’ Route to appropriate manager: - - NIXL β†’ NixlTransferManager.add_remote_agent(nixl_metadata) - - TransferEngine β†’ TransferEngineTransferManager.connect(te_metadata) -``` - -**Implementation**: -```python -def load_model_from_source(model_name: str): - metadata = client.get_metadata(model_name) - - for worker in metadata.workers: - if worker.backend_type == BackendType.NIXL: - manager = get_nixl_manager(worker.worker_rank) - manager.add_remote_agent(worker.backend_metadata) - elif worker.backend_type == BackendType.TRANSFER_ENGINE: - manager = get_transfer_engine_manager(worker.worker_rank) - te_config = deserialize_transfer_engine_metadata(worker.backend_metadata) - manager.connect_to_seed(te_config) -``` - -### 6. Error Handling & Validation - -#### ⚠️ **Concerns** - -1. **Mismatched Backends**: What if source uses TransferEngine but target only has NIXL? -2. **Metadata Corruption**: Invalid backend_metadata for declared backend_type -3. **Connection Failures**: TransferEngine seed instance unreachable - -**Recommendations**: -1. **Validation**: Server should validate backend_type matches metadata format -2. **Error Messages**: Clear errors: "Source uses TransferEngine but target only supports NIXL" -3. **Fallback**: Consider automatic fallback if preferred backend unavailable (with user opt-in) - -**Example Validation**: -```rust -fn validate_worker_metadata(worker: &WorkerMetadata) -> Result<()> { - match worker.backend_type { - BackendType::Nixl => { - if worker.backend_metadata.is_empty() { - return Err("NIXL backend requires non-empty metadata"); - } - // Could also validate NIXL metadata format - } - BackendType::TransferEngine => { - let te_meta = deserialize_transfer_engine_metadata(&worker.backend_metadata)?; - if te_meta.seed_instance_ip.is_empty() { - return Err("TransferEngine requires seed_instance_ip"); - } - } - _ => return Err("Unsupported backend type"), - } - Ok(()) -} -``` - -### 7. Testing & Compatibility - -#### βœ… **Recommendations** - -1. **Unit Tests**: - - Test backend type serialization/deserialization - - Test migration from old format (nixl_metadata) to new format - - Test validation logic - -2. **Integration Tests**: - - Test NIXL-only deployment (backward compatibility) - - Test TransferEngine-only deployment - - Test mixed deployment (some workers NIXL, some TransferEngine) - -3. **Compatibility Tests**: - - Old client β†’ New server (should work) - - New client β†’ Old server (should handle gracefully) - -**Example Test**: -```rust -#[test] -fn test_backward_compatibility_old_nixl_metadata() { - // Simulate old WorkerMetadata with only nixl_metadata field - let old_meta = WorkerMetadata { - worker_rank: 0, - backend_type: BackendType::Unspecified, // Old format - nixl_metadata: vec![1, 2, 3, 4], // Old field - backend_metadata: None, // New field not set - tensors: vec![], - }; - - let record = WorkerRecord::from(old_meta); - assert_eq!(record.backend_type, BackendType::Nixl); // Auto-detected - assert_eq!(record.backend_metadata, vec![1, 2, 3, 4]); -} -``` - -## Specific PR Feedback Items - -### High Priority - -1. **Backward Compatibility**: Ensure existing NIXL-only deployments continue to work without changes -2. **Protocol Buffer Design**: Use `oneof` for backend-specific metadata (see Section 1) -3. **Storage Layer**: Rename `nixl_metadata` to `backend_metadata` and add `backend_type` field -4. **Validation**: Add server-side validation that backend_type matches metadata format - -### Medium Priority - -5. **Client Factory**: Implement factory pattern for transfer manager creation -6. **Error Handling**: Clear error messages for backend mismatches -7. **Documentation**: Update `docs/metadata.md` with TransferEngine backend information -8. **Configuration**: Align parameter names with SGLang's R-Fork for consistency - -### Low Priority - -9. **Future-Proofing**: Consider NCCL backend support (similar pattern) -10. **Observability**: Add metrics/logging for backend type usage -11. **Testing**: Comprehensive test coverage for all backend combinations - -## Questions for PR Author - -1. **Migration Strategy**: How are existing deployments migrated? Is there a migration script? -2. **Backend Selection**: How does the system decide which backend to use? User config or auto-detection? -3. **Mixed Deployments**: Can a single model have workers using different backends (e.g., worker 0 NIXL, worker 1 TransferEngine)? -4. **TransferEngine Implementation**: Is TransferEngine a separate library, or is it part of NIXL? What are the dependencies? -5. **Performance Comparison**: Are there benchmarks comparing NIXL vs TransferEngine performance? -6. **SGLang Integration**: Is this change intended to enable ModelExpress to work with SGLang's R-Fork feature? - -## Conclusion - -The addition of TransferEngine backend support is a valuable enhancement that aligns ModelExpress with SGLang's R-Fork capabilities. The key concerns are: - -1. **Design**: Use `oneof` for backend-specific metadata to ensure type safety and extensibility -2. **Compatibility**: Maintain backward compatibility with existing NIXL deployments -3. **Validation**: Ensure backend_type and metadata format are consistent -4. **Testing**: Comprehensive test coverage for all scenarios - -The recommended approach provides a clean, extensible design that can support additional backends (NCCL, custom) in the future while maintaining compatibility with existing deployments. - ---- - -## PR Review Comments - -This section provides specific comments to make directly on PR 157, organized by file and approximate line numbers. These comments should be added as inline code review comments on the PR. - -**Note**: These comments are based on the actual implementation in the `ishan/transfer-engine-backend` branch. The PR has already implemented the `oneof` pattern for backend metadata, which is excellent! The comments below address specific aspects of the implementation. - -### Protocol Buffer Changes - -#### File: `modelexpress_common/proto/p2p.proto` - -**Comment 1 - Line ~57-60 (WorkerMetadata message)** -``` -βœ… Excellent: Using `oneof` for backend metadata! - -Great implementation! The `oneof backend_metadata` pattern provides type safety -and clear separation. One suggestion: - -Consider adding a comment explaining the format of `transfer_engine_session_id`: -```protobuf -// TransferEngine: Mooncake session ID in format "ip:port" (e.g., "10.0.0.1:8000") -string transfer_engine_session_id = 10; -``` - -This helps users understand the expected format. Also, consider if a structured -message would be better for future extensibility (e.g., if you need to add -additional TransferEngine connection parameters later). -``` - -**Comment 2 - Line ~50-60 (WorkerMetadata message)** -``` -⚠️ Backward Compatibility Concern - -If the existing `bytes nixl_metadata = 2` field is being kept for compatibility, -please ensure: - -1. The field is marked as deprecated in comments -2. Server-side conversion handles both old and new formats -3. Auto-detection: If `backend_type` is unset but `nixl_metadata` is present, - infer `BACKEND_TYPE_NIXL` - -This is critical for existing deployments that won't be updated immediately. -``` - -**Comment 3 - Line ~92-97 (if BackendType enum is added)** -``` -βœ… Good: BackendType enum definition - -If adding a BackendType enum, ensure: -- `BACKEND_TYPE_UNSPECIFIED = 0` is the default (protobuf best practice) -- Values match the pattern used in SGLang's R-Fork for consistency -- Consider future-proofing with `BACKEND_TYPE_NCCL = 3` even if not implemented yet -``` - -**Comment 4 - Line ~103-109 (if TransferEngineBackendMetadata message is added)** -``` -πŸ“ Documentation Suggestion - -The TransferEngineBackendMetadata message should include: -- `seed_instance_ip`: IP address of seed instance (required) -- `seed_instance_service_port`: HTTP service port (required) -- `send_weights_group_ports`: For NCCL backend variant (optional, repeated) -- Comments explaining each field's purpose - -Consider aligning field names with SGLang's R-Fork parameters for familiarity: -- `--remote-instance-weight-loader-seed-instance-ip` -- `--remote-instance-weight-loader-seed-instance-service-port` -``` - -### Server-Side Rust Changes - -#### File: `modelexpress_server/src/metadata_backend.rs` - -**Comment 5 - Line ~64-68 (WorkerRecord struct)** -``` -βœ… Excellent: Using `BackendMetadataRecord` enum! - -Great design! The `BackendMetadataRecord` enum provides type safety and makes -the backend type explicit. The implementation looks clean. - -One observation: The `BackendMetadataRecord::None` variant (line 43) - is this -intentionally allowed? If a worker has no backend metadata, should we reject -it during validation, or is this for a specific use case? Consider adding -validation in `publish_metadata` to ensure at least one backend is provided. -``` - -**Comment 6 - Line ~81-96 (From for WorkerRecord)** -``` -βœ… Clean Implementation: Conversion logic looks good! - -The conversion from `WorkerMetadata` to `WorkerRecord` correctly handles the -`oneof` pattern. One suggestion: - -Consider adding validation to ensure at least one backend metadata is provided: -```rust -impl From for WorkerRecord { - fn from(meta: WorkerMetadata) -> Self { - use modelexpress_common::grpc::p2p::worker_metadata::BackendMetadata; - let backend_metadata = match meta.backend_metadata { - Some(BackendMetadata::NixlMetadata(data)) => { - if data.is_empty() { - tracing::warn!("Empty NIXL metadata for worker {}", meta.worker_rank); - } - BackendMetadataRecord::Nixl(data) - } - Some(BackendMetadata::TransferEngineSessionId(sid)) => { - if sid.is_empty() { - tracing::warn!("Empty TransferEngine session ID for worker {}", meta.worker_rank); - } - BackendMetadataRecord::TransferEngine(sid) - } - None => { - tracing::warn!("No backend metadata provided for worker {}", meta.worker_rank); - BackendMetadataRecord::None - } - }; - ... - } -} -``` - -This helps catch configuration errors early. -``` - -**Comment 7 - Line ~77-88 (From for WorkerMetadata)** -``` -πŸ”„ Conversion Logic: Ensure bidirectional conversion works - -The reverse conversion `From for WorkerMetadata` must: -1. Set `backend_type` field correctly -2. Populate the appropriate `oneof` field based on `backend_type` -3. Handle legacy `nixl_metadata` field for backward compatibility - -This ensures targets can correctly deserialize and route to the right backend. -``` - -#### File: `modelexpress_server/src/p2p_service.rs` - -**Comment 8 - Line ~49-59 (BackendMetadataRecord::from_flat)** -``` -⚠️ Priority Logic: TransferEngine takes priority - -The `from_flat` method gives TransferEngine priority when both `nixl_metadata` -and `transfer_engine_session_id` are present (line 50-53). This is reasonable, -but consider: - -1. **Documentation**: Add a comment explaining why TransferEngine takes priority -2. **Validation**: Should we warn or error if both are provided? It might indicate - a configuration mistake -3. **Consistency**: Ensure this priority is consistent across all code paths - -Suggestion: -```rust -pub fn from_flat(nixl_metadata: Vec, transfer_engine_session_id: Option) -> Self { - if let Some(sid) = transfer_engine_session_id - && !sid.is_empty() - { - // TransferEngine takes priority when both are present - if !nixl_metadata.is_empty() { - tracing::warn!( - "Both NIXL and TransferEngine metadata provided, using TransferEngine" - ); - } - return Self::TransferEngine(sid); - } - ... -} -``` -``` - -**Comment 9 - Line ~84-119 (get_metadata implementation)** -``` -πŸ“Š Logging Enhancement - -When returning metadata, log the backend types being returned: -```rust -info!( - "Found metadata for model '{}': {} workers (backends: {:?}), {} tensors", - req.model_name, - record.workers.len(), - record.workers.iter().map(|w| w.backend_type).collect::>(), - total_tensors -); -``` - -This helps with debugging mixed-backend deployments. -``` - -### Client-Side Python Changes - -#### File: `modelexpress_client/python/modelexpress/nixl_transfer.py` (or new file) - -**Comment 10 - Line ~1-50 (if creating TransferEngineTransferManager)** -``` -🏭 Factory Pattern Suggestion - -Consider creating a factory for transfer managers to handle backend selection: - -```python -class TransferManagerFactory: - @staticmethod - def create( - backend_type: BackendType, - agent_name: str, - device_id: int, - **kwargs - ) -> TransferManager: - if backend_type == BackendType.NIXL: - return NixlTransferManager(agent_name, device_id) - elif backend_type == BackendType.TRANSFER_ENGINE: - return TransferEngineTransferManager( - agent_name, device_id, - seed_instance_ip=kwargs.get("seed_instance_ip"), - seed_instance_port=kwargs.get("seed_instance_port"), - ) - else: - raise ValueError(f"Unsupported backend: {backend_type}") -``` - -This provides clean separation and makes testing easier. -``` - -**Comment 11 - Line ~37-40 (is_nixl_available function)** -``` -πŸ” Backend Availability Detection - -Add a similar function for TransferEngine: -```python -def is_transfer_engine_available() -> bool: - """Check if TransferEngine is available.""" - try: - # Import TransferEngine library - from transfer_engine import TransferEngine - return True - except ImportError: - return False -``` - -Also consider a backend selection function with fallback: -```python -def select_available_backend(preferred: BackendType) -> BackendType: - """Select available backend with fallback.""" - if preferred == BackendType.TRANSFER_ENGINE: - if is_transfer_engine_available(): - return BackendType.TRANSFER_ENGINE - elif is_nixl_available(): - logger.warning("TransferEngine not available, falling back to NIXL") - return BackendType.NIXL - ... -``` -``` - -#### File: `modelexpress_client/python/modelexpress/` (vLLM loader integration) - -**Comment 12 - Line ~TBD (where metadata is consumed)** -``` -πŸ”„ Backend-Aware Routing Required - -When target receives metadata from `get_metadata()`, ensure it routes to the -correct backend based on `backend_type`: - -```python -def load_model_from_source(model_name: str): - metadata = client.get_metadata(model_name) - - for worker in metadata.workers: - if worker.backend_type == BackendType.NIXL: - manager = get_nixl_manager(worker.worker_rank) - manager.add_remote_agent(worker.backend_metadata) - elif worker.backend_type == BackendType.TRANSFER_ENGINE: - manager = get_transfer_engine_manager(worker.worker_rank) - te_config = deserialize_transfer_engine_metadata(worker.backend_metadata) - manager.connect_to_seed(te_config) - else: - raise ValueError(f"Unsupported backend: {worker.backend_type}") -``` - -This ensures targets can handle sources using different backends. -``` - -**Comment 13 - Line ~TBD (error handling)** -``` -⚠️ Error Handling: Backend Mismatch - -Add clear error handling when source and target backends don't match: - -```python -if worker.backend_type == BackendType.TRANSFER_ENGINE: - if not is_transfer_engine_available(): - raise RuntimeError( - f"Source worker {worker.worker_rank} uses TransferEngine backend, " - "but TransferEngine is not available on this target. " - "Please install TransferEngine or use a source with NIXL backend." - ) -``` - -Provide actionable error messages to help users resolve issues. -``` - -### Storage Backend Changes - -#### File: `modelexpress_server/src/metadata_backend/redis.rs` - -**Comment 14 - Line ~125-147 (WorkerRecordJson)** -``` -πŸ”„ JSON Serialization: Handle backend_type field - -The `WorkerRecordJson` struct needs to include `backend_type`: - -```rust -#[derive(Debug, Clone, Serialize, Deserialize)] -struct WorkerRecordJson { - pub worker_rank: u32, - pub backend_type: Option, // NEW: Optional for backward compat - pub backend_metadata: Vec, // RENAMED from nixl_metadata - pub tensors: Vec, -} -``` - -In `From for WorkerRecord`, default to `BackendType::Nixl` -if `backend_type` is `None` (for old stored data). -``` - -#### File: `modelexpress_server/src/metadata_backend/kubernetes.rs` - -**Comment 15 - Line ~TBD (WorkerStatus in k8s_types.rs)** -``` -πŸ“ Kubernetes CRD: Add backend_type field - -The `WorkerStatus` struct in `k8s_types.rs` should include: -```rust -pub struct WorkerStatus { - pub worker_rank: i32, - pub backend_type: Option, // "nixl", "transfer_engine", etc. - pub nixl_metadata: String, // Consider renaming to backend_metadata - ... -} -``` - -Update the CRD schema in `examples/p2p_transfer_k8s/deploy/persistence/crd-modelmetadata.yaml` -to include the backend_type field. -``` - -### Testing - -#### File: `modelexpress_server/src/metadata_backend.rs` (test module) - -**Comment 16 - Line ~TBD (add new tests)** -``` -βœ… Test Coverage Needed - -Please add tests for: -1. **Backward compatibility**: Old WorkerMetadata with only `nixl_metadata` field -2. **New format**: WorkerMetadata with `backend_type` and `oneof` fields -3. **Migration**: Conversion from old to new format -4. **Validation**: Reject invalid backend_type/metadata combinations -5. **Mixed deployments**: Model with some workers NIXL, some TransferEngine - -Example: -```rust -#[test] -fn test_backward_compatibility_old_nixl_metadata() { - let old_meta = WorkerMetadata { - worker_rank: 0, - backend_type: BackendType::Unspecified, - nixl_metadata: vec![1, 2, 3, 4], // Old field - backend_metadata: None, // New field not set - tensors: vec![], - }; - - let record = WorkerRecord::from(old_meta); - assert_eq!(record.backend_type, BackendType::Nixl); // Auto-detected -} -``` -``` - -### Documentation - -#### File: `docs/metadata.md` - -**Comment 17 - Line ~1-10 (Overview section)** -``` -πŸ“ Documentation Update Needed - -Please update the overview to mention TransferEngine backend support: - -```markdown -## Overview - -ModelExpress P2P transfers require coordination between source and target instances: -1. **Source** publishes transfer backend metadata (NIXL agent info or TransferEngine - connection info + tensor descriptors) after loading model weights -2. **Target** queries for source metadata to establish connections (RDMA for NIXL, - TransferEngine connection for TransferEngine backend) -3. **Coordination** signals ensure targets wait for sources to be fully ready -``` - -Also add a new section explaining TransferEngine backend usage and configuration. -``` - -**Comment 18 - Line ~TBD (add TransferEngine section)** -``` -πŸ“š New Section: TransferEngine Backend - -Add a section explaining: -1. When to use TransferEngine vs NIXL -2. Configuration parameters (align with SGLang R-Fork) -3. Example usage -4. Troubleshooting common issues - -Reference: https://raw.githubusercontent.com/sgl-project/sglang/main/docs/advanced_features/rfork.md -``` - -### Configuration & Environment Variables - -#### File: `README.md` or new config documentation - -**Comment 19 - Line ~TBD** -``` -βš™οΈ Configuration Documentation - -Document the new environment variables for TransferEngine: -- `MX_TRANSFER_BACKEND`: Backend type (`nixl`, `transfer_engine`, default: `nixl`) -- `MX_TRANSFER_ENGINE_SEED_IP`: Seed instance IP (required for TransferEngine) -- `MX_TRANSFER_ENGINE_SEED_PORT`: Seed instance service port (required for TransferEngine) - -Align naming with SGLang's R-Fork parameters for consistency. -``` - -### Additional Comments Based on Actual Implementation - -#### File: `modelexpress_common/proto/p2p.proto` - -**Comment 20 - Line ~59 (transfer_engine_session_id field)** -``` -πŸ“ Format Documentation Needed - -The `transfer_engine_session_id` is described as "ip:port" format. Consider: - -1. **Validation**: Add format validation (e.g., regex or parsing) to ensure it's - a valid "ip:port" format -2. **Documentation**: Add example in comment: `// Format: "10.0.0.1:8000"` -3. **Future-proofing**: If you need additional TransferEngine connection parameters - later (e.g., authentication tokens, protocol version), consider using a structured - message instead of a string - -Current approach is fine for MVP, but structured message would be more extensible: -```protobuf -message TransferEngineBackendMetadata { - string seed_instance_ip = 1; - uint32 seed_instance_service_port = 2; - // Future: repeated uint32 send_weights_group_ports = 3; -} -``` -``` - -#### File: `modelexpress_server/src/metadata_backend.rs` - -**Comment 21 - Line ~43 (BackendMetadataRecord::None)** -``` -❓ Design Question: When is `None` valid? - -The `BackendMetadataRecord::None` variant suggests workers can exist without -backend metadata. Is this intentional? Consider: - -1. **Use case**: When would a worker have no backend metadata? Is this for - a specific deployment scenario? -2. **Validation**: Should `publish_metadata` reject workers with `None` backend? -3. **Documentation**: Add a comment explaining when `None` is acceptable - -If `None` is not a valid state, consider removing it and making the enum -non-optional, or add validation to reject it. -``` - -**Comment 22 - Line ~49-59 (from_flat priority logic)** -``` -βœ… Good: Priority logic is clear - -The priority logic (TransferEngine > NIXL > None) is reasonable. One enhancement: - -Consider logging when priority is applied to help with debugging: -```rust -pub fn from_flat(nixl_metadata: Vec, transfer_engine_session_id: Option) -> Self { - let has_nixl = !nixl_metadata.is_empty(); - let has_te = transfer_engine_session_id.as_ref() - .map(|s| !s.is_empty()) - .unwrap_or(false); - - if has_te && has_nixl { - tracing::debug!( - "Both NIXL and TransferEngine metadata present, using TransferEngine (priority)" - ); - } - - if let Some(sid) = transfer_engine_session_id - && !sid.is_empty() - { - return Self::TransferEngine(sid); - } - ... -} -``` -``` - -### Summary of Priority Comments - -**High Priority (Must Address)**: -- Comment 2: Backward compatibility handling (if old format still exists) -- Comment 8: Priority logic documentation and validation -- Comment 20: TransferEngine session ID format validation -- Comment 21: Clarify when `BackendMetadataRecord::None` is valid - -**Medium Priority (Should Address)**: -- Comment 5: Validation for empty backend metadata -- Comment 6: Add validation/warnings for empty metadata -- Comment 10: Factory pattern for transfer managers (client-side) -- Comment 12: Backend-aware routing in client -- Comment 16: Test coverage for all scenarios -- Comment 17: Documentation updates - -**Low Priority (Nice to Have)**: -- Comment 1: Enhanced documentation for TransferEngine session ID format -- Comment 9: Enhanced logging -- Comment 11: Backend availability detection with fallback -- Comment 19: Configuration documentation -- Comment 22: Enhanced logging for priority logic - ---- - -## Backend Selection Logic: TransferEngine vs NIXL - -### Current State - -Based on the codebase review, **the backend selection logic is not yet fully implemented**. Here's what I found: - -1. **Protocol Support**: The `p2p.proto` file supports both backends via `oneof`: - ```protobuf - oneof backend_metadata { - bytes nixl_metadata = 2; - string transfer_engine_session_id = 10; - } - ``` - -2. **Client Implementation**: Currently, the client code (`vllm_loader.py` line 338) **only sets NIXL metadata**: - ```python - worker = p2p_pb2.WorkerMetadata( - worker_rank=device_id, - nixl_metadata=nixl_metadata, # Only NIXL is set - tensors=tensor_protos, - ) - ``` - -3. **No Selection Logic**: There's no configuration or code that chooses between TransferEngine and NIXL. - -### How It Should Work - -The backend selection should happen at **two points**: - -#### 1. Source Side (When Publishing Metadata) - -The source decides which backend to use based on: -- **Configuration**: Environment variable or config file -- **Availability**: Runtime detection of which backends are available -- **User preference**: Explicit configuration - -**Recommended Implementation**: -```python -# In vllm_loader.py _publish_metadata_to_server() -def _publish_metadata_to_server(self, raw_tensors, device_id): - # Determine which backend to use - backend_type = self._select_backend() # NEW: Selection logic - - if backend_type == "transfer_engine": - # Initialize TransferEngine and get session ID - te_session_id = self._get_transfer_engine_session_id() - worker = p2p_pb2.WorkerMetadata( - worker_rank=device_id, - transfer_engine_session_id=te_session_id, # Set TE field - tensors=tensor_protos, - ) - else: # Default to NIXL - nixl_metadata = self._nixl_manager.nixl_metadata if self._nixl_manager else b"" - worker = p2p_pb2.WorkerMetadata( - worker_rank=device_id, - nixl_metadata=nixl_metadata, # Set NIXL field - tensors=tensor_protos, - ) - - self._mx_client.publish_metadata(model_name, [worker]) - -def _select_backend(self) -> str: - """Select backend based on configuration and availability.""" - # Check explicit configuration - configured_backend = os.environ.get("MX_TRANSFER_BACKEND", "nixl") - - if configured_backend == "transfer_engine": - if is_transfer_engine_available(): - return "transfer_engine" - else: - logger.warning("TransferEngine not available, falling back to NIXL") - return "nixl" - else: - return "nixl" # Default -``` - -#### 2. Target Side (When Receiving Metadata) - -The target must use **whatever backend the source published**. The target cannot choose - it must match the source's backend. - -**Recommended Implementation**: -```python -# In vllm_loader.py load_model() for target -def load_model(self, ...): - # Get metadata from server - metadata_response = self._mx_client.get_metadata(model_name) - - for worker in metadata_response.workers: - # Check which backend the source used - if worker.HasField("transfer_engine_session_id"): - # Source uses TransferEngine - if not is_transfer_engine_available(): - raise RuntimeError( - f"Source worker {worker.worker_rank} uses TransferEngine, " - "but TransferEngine is not available on this target" - ) - # Use TransferEngine to connect - self._connect_via_transfer_engine(worker.transfer_engine_session_id) - - elif worker.HasField("nixl_metadata"): - # Source uses NIXL - if not is_nixl_available(): - raise RuntimeError( - f"Source worker {worker.worker_rank} uses NIXL, " - "but NIXL is not available on this target" - ) - # Use NIXL to connect - self._connect_via_nixl(worker.nixl_metadata) - else: - raise RuntimeError("Source worker has no backend metadata") -``` - -### Configuration Options - -**Recommended Environment Variables**: - -1. **`MX_TRANSFER_BACKEND`**: Primary backend selection - - Values: `nixl` (default), `transfer_engine`, `auto` - - `auto`: Try TransferEngine first, fallback to NIXL - -2. **`MX_TRANSFER_ENGINE_ENABLED`**: Explicit enable/disable - - Values: `true`, `false` (default: `false`) - - Overrides `MX_TRANSFER_BACKEND` if set to `false` - -3. **Runtime Detection**: Check availability at runtime - ```python - def is_transfer_engine_available() -> bool: - try: - from transfer_engine import TransferEngine - return True - except ImportError: - return False - ``` - -### Priority/Precedence Rules - -1. **Source publishes with one backend** β†’ Target must use the same backend -2. **If source uses TransferEngine but target doesn't have it** β†’ Error (clear message) -3. **If source uses NIXL but target doesn't have it** β†’ Error (clear message) -4. **If both are available** β†’ Use source's choice (no negotiation) - -### Missing Implementation - -Based on the code review, the following is **missing**: - -1. βœ… **Proto support**: Already implemented (`oneof` pattern) -2. ❌ **Source selection logic**: Not implemented (always uses NIXL) -3. ❌ **Target routing logic**: Not implemented (always expects NIXL) -4. ❌ **TransferEngine client code**: Not implemented -5. ❌ **Configuration variables**: Not documented/implemented -6. ❌ **Availability detection**: Not implemented - -### Recommendation - -Add explicit backend selection logic to the client code: - -1. **Add configuration**: `MX_TRANSFER_BACKEND` environment variable -2. **Add selection method**: `_select_backend()` in `MxSourceModelLoader` -3. **Add routing method**: Check `HasField()` in `MxTargetModelLoader` -4. **Add TransferEngine manager**: Similar to `NixlTransferManager` -5. **Add tests**: Test both backends and mixed scenarios - -This ensures the backend selection is **explicit and configurable**, rather than implicit or hardcoded. diff --git a/docs/feedback_pr19920.md b/docs/feedback_pr19920.md deleted file mode 100644 index f97f9538..00000000 --- a/docs/feedback_pr19920.md +++ /dev/null @@ -1,863 +0,0 @@ -# PR 19920: [1/2] Add ModelExpress coordination for remote instance weight loading - matching TP - -## Executive Summary - -This document provides a design review and feedback for SGLang PR 19920, which adds ModelExpress coordination for remote instance weight loading. The PR integrates ModelExpress gRPC server as a coordination layer for TransferEngine-based weight transfers, replacing direct HTTP communication between seed and target instances. - -**Key Changes:** -- Adds `MODEL_EXPRESS` backend option for `remote_instance_weight_loader_backend` -- Integrates ModelExpress client for metadata coordination -- Supports TP rank matching between seed and target instances -- Uses TransferEngine for actual RDMA transfers (coordinated via ModelExpress) - -## Architecture Overview - -### Current Flow (Before PR 19920) - -**NCCL Backend:** -``` -Seed Instance β†’ Direct HTTP β†’ Target Instance - - Seed publishes TransferEngine session ID via HTTP endpoint - - Target queries seed HTTP endpoint for session ID - - Target connects directly to seed via TransferEngine -``` - -**TransferEngine Backend (Direct):** -``` -Seed Instance β†’ HTTP endpoint β†’ Target Instance - - Seed exposes /get_remote_instance_transfer_engine_info - - Target queries per-rank session IDs - - Direct TransferEngine connection -``` - -### New Flow (After PR 19920) - -**ModelExpress Backend:** -``` -Seed Instance β†’ ModelExpress Server β†’ Target Instance - - Seed publishes metadata to ModelExpress gRPC server - - Target queries ModelExpress for seed metadata - - ModelExpress coordinates ready state - - Target connects to seed via TransferEngine (using session ID from metadata) -``` - -## Implementation Review - -### 1. Protocol Buffer Integration - -#### βœ… **Good: Correct Use of `oneof` Pattern** - -**File**: `python/sglang/srt/model_loader/loader.py` (line ~2340) - -The implementation correctly uses the `oneof` pattern to extract TransferEngine session ID: - -```python -backend_field = source_worker.WhichOneof("backend_metadata") -if backend_field == "transfer_engine_session_id": - seed_session_id = source_worker.transfer_engine_session_id -else: - raise RuntimeError( - f"ModelExpress: expected transfer_engine_session_id, " - f"got backend_metadata={backend_field}" - ) -``` - -**Comment**: This correctly handles the `oneof` pattern from ModelExpress PR 157. Good error handling when the wrong backend type is present. - -#### ⚠️ **Concern: No Fallback for NIXL Backend** - -**Issue**: The code only handles `transfer_engine_session_id` and raises an error for other backends. What if the source uses NIXL backend? - -**Recommendation**: Add support for NIXL backend or provide a clear error message: - -```python -backend_field = source_worker.WhichOneof("backend_metadata") -if backend_field == "transfer_engine_session_id": - seed_session_id = source_worker.transfer_engine_session_id -elif backend_field == "nixl_metadata": - raise RuntimeError( - f"ModelExpress: source worker {tp_rank} uses NIXL backend, " - f"but MODEL_EXPRESS backend requires TransferEngine. " - f"Please use a source with TransferEngine backend or use NIXL directly." - ) -else: - raise RuntimeError( - f"ModelExpress: unknown backend_metadata={backend_field} " - f"for worker {tp_rank}" - ) -``` - -### 2. Source Side: Publishing Metadata - -#### βœ… **Good: Proper Metadata Publishing** - -**File**: `python/sglang/srt/model_executor/model_runner.py` (line ~680-750) - -The `_publish_model_express_metadata()` function: -- Correctly builds tensor descriptors from weight info -- Uses `transfer_engine_session_id` in the `oneof` field -- Publishes both metadata and ready flag -- Handles element size to dtype mapping for FP8 models - -**Comment**: The implementation correctly uses byte sizes (`numel * element_size`) for tensor descriptors, which is important for mixed-dtype models (FP8 + BF16). - -#### ⚠️ **Concern: Dtype Inference from Element Size** - -**File**: `python/sglang/srt/model_executor/model_runner.py` (line ~700) - -```python -element_size_to_dtype = {1: "float8_e4m3fn", 2: "bfloat16", 4: "float32", 8: "float64"} -``` - -**Issue**: This mapping is lossy. Multiple dtypes can have the same element size: -- Element size 2: `float16`, `bfloat16`, `int16`, `uint16` -- Element size 1: `int8`, `uint8`, `float8_e4m3fn`, `float8_e5m2` - -**Recommendation**: Use actual tensor dtype instead of inferring from element size: - -```python -tensors = [] -for name, (addr, numel, element_size) in weight_info.items(): - # Get actual tensor to determine dtype - tensor = dict(model.named_parameters())[name] - dtype_str = str(tensor.dtype).replace("torch.", "") - - tensors.append(p2p_pb2.TensorDescriptor( - name=name, - addr=addr, - size=numel * element_size, - device_id=self.gpu_id, - dtype=dtype_str, # Use actual dtype - )) -``` - -**Alternative**: If weight_info doesn't include tensor references, add dtype to the weight_info tuple: -```python -# In register_memory_region, return (addr, numel, element_size, dtype_str) -weight_info[name] = (addr, numel, element_size, str(tensor.dtype).replace("torch.", "")) -``` - -#### βœ… **Good: TP Rank Matching** - -**File**: `python/sglang/srt/model_loader/loader.py` (line ~2310) - -The code correctly matches TP ranks: -```python -for w in response.workers: - if w.worker_rank == tp_rank: - source_worker = w - break -``` - -This ensures each target TP rank connects to the corresponding seed TP rank, which is critical for tensor parallelism. - -### 3. Target Side: Loading Weights - -#### βœ… **Good: Byte Size Matching** - -**File**: `python/sglang/srt/model_loader/loader.py` (line ~2370) - -The code correctly uses byte sizes for matching: -```python -seed_ptr, seed_size = weight_info -local_size = tensor.numel() * tensor.element_size() -if seed_size != local_size: - raise RuntimeError(...) -``` - -**Comment**: This is correct! RDMA is a memcpy operation, so byte size matching is sufficient. Dtype differences (e.g., FP8 vs BF16) are handled by the model's quantization logic, not the transfer layer. - -#### ⚠️ **Concern: Missing Tensor Name Validation** - -**Issue**: The code assumes tensor names match exactly between seed and target. What if: -- Model architectures differ slightly? -- Tensor names have different prefixes? -- Some tensors are missing? - -**Recommendation**: Add more robust matching: - -```python -for name, tensor in model.named_parameters(): - weight_info = seed_weight_info.get(name, None) - if weight_info is None: - # Try fuzzy matching or provide helpful error - logger.warning( - f"ModelExpress: tensor '{name}' not found in seed metadata. " - f"Available tensors: {list(seed_weight_info.keys())[:10]}..." - ) - raise RuntimeError( - f"ModelExpress: cannot find weight info for {name} " - f"in seed metadata. This may indicate a model architecture mismatch." - ) -``` - -#### βœ… **Good: Ready State Coordination** - -**File**: `python/sglang/srt/model_loader/loader.py` (line ~2280) - -The code correctly waits for seed ready state: -```python -ready, session_id, metadata_hash = mx_client.wait_for_ready( - model_name, worker_id=tp_rank, -) -``` - -This ensures the target doesn't start transferring before the seed is fully initialized and stable. - -### 4. Configuration & CLI Arguments - -#### βœ… **Good: Clear CLI Arguments** - -**File**: `python/sglang/srt/server_args.py` - -The PR adds three new CLI arguments: -- `--model-express-url`: ModelExpress server URL -- `--model-express-model-name`: Model name for coordination -- `--model-express-source`: Flag to run as seed source - -**Comment**: The arguments are well-named and follow SGLang's existing patterns. - -#### ⚠️ **Concern: Validation Logic** - -**File**: `python/sglang/srt/server_args.py` (line ~2722) - -```python -if self.remote_instance_weight_loader_backend == "model_express": - if self.model_express_url is None: - logger.warning("Fallback load_format to 'auto'...") - self.load_format = "auto" -``` - -**Issue**: The validation silently falls back to `auto` instead of raising an error. This could lead to confusion. - -**Recommendation**: Make validation stricter or provide clearer messaging: - -```python -if self.remote_instance_weight_loader_backend == "model_express": - if self.model_express_url is None: - raise ValueError( - "--model-express-url is required when using " - "--remote-instance-weight-loader-backend=model_express" - ) - if not self.validate_transfer_engine(): - raise ValueError( - "TransferEngine is required for model_express backend. " - "Please install mooncake.engine or use a different backend." - ) -``` - -#### ⚠️ **Concern: Model Name Default** - -**File**: `python/sglang/srt/model_executor/model_runner.py` (line ~685) - -```python -model_name = ( - self.server_args.model_express_model_name - or self.server_args.model_path -) -``` - -**Issue**: Using `model_path` as default could lead to inconsistent model names (e.g., `/path/to/model` vs `meta-llama/Llama-3.1-70B`). - -**Recommendation**: Use a more consistent default or require explicit model name: - -```python -model_name = self.server_args.model_express_model_name -if not model_name: - # Extract model name from model_path (e.g., last component) - model_name = os.path.basename(self.server_args.model_path.rstrip('/')) - logger.warning( - f"ModelExpress: using model_name='{model_name}' from model_path. " - f"Consider setting --model-express-model-name explicitly." - ) -``` - -### 5. Error Handling - -#### βœ… **Good: Comprehensive Error Messages** - -The code provides clear error messages for common failure modes: -- Missing metadata -- Worker rank mismatch -- Size mismatches -- TransferEngine failures - -#### ⚠️ **Concern: Timeout Handling** - -**File**: `python/sglang/srt/model_loader/loader.py` (line ~2280) - -```python -ready, session_id, metadata_hash = mx_client.wait_for_ready( - model_name, worker_id=tp_rank, -) -if not ready: - raise RuntimeError("ModelExpress: timed out waiting for seed ready...") -``` - -**Issue**: The timeout is not configurable and may not be visible in the error message. - -**Recommendation**: Add timeout parameter and include it in error: - -```python -timeout_seconds = load_config.model_express_ready_timeout or 7200 # 2 hours default -ready, session_id, metadata_hash = mx_client.wait_for_ready( - model_name, worker_id=tp_rank, timeout_seconds=timeout_seconds, -) -if not ready: - raise RuntimeError( - f"ModelExpress: timed out waiting for seed ready " - f"(model={model_name}, worker={tp_rank}, timeout={timeout_seconds}s)" - ) -``` - -### 6. Integration with TransferEngine - -#### βœ… **Good: Reuses Existing TransferEngine Infrastructure** - -The PR correctly reuses: -- `register_memory_region()` for memory registration -- `batch_transfer_sync_read()` for RDMA transfers -- Existing TransferEngine initialization logic - -**Comment**: This is a clean integration that doesn't duplicate code. - -#### ⚠️ **Concern: TransferEngine Initialization Timing** - -**File**: `python/sglang/srt/model_executor/model_runner.py` (line ~1075) - -For seed sources, TransferEngine weight info is registered in `model_specific_adjustment()`: - -```python -if self.server_args.model_express_source: - if self.remote_instance_transfer_engine_weight_info is None: - self.remote_instance_transfer_engine_weight_info = ( - register_memory_region(self.model, self.remote_instance_transfer_engine) - ) - self._publish_model_express_metadata() -``` - -**Issue**: This happens after model loading. If the model is loaded via `DefaultModelLoader` (load_format=auto), the weights may have been processed/quantized, which could affect memory addresses. - -**Recommendation**: Document this timing and ensure weights are stable before registration: - -```python -# Ensure model weights are finalized before registering -# (post_load_weights may modify weights) -if hasattr(self.model, "post_load_weights"): - self.model.post_load_weights() - -# Now register memory regions (weights are stable) -if self.server_args.model_express_source: - ... -``` - -### 7. Testing & Edge Cases - -#### ❓ **Missing: Test Coverage** - -**Questions**: -1. Are there unit tests for `load_model_from_model_express()`? -2. Are there integration tests for the full flow (seed β†’ ModelExpress β†’ target)? -3. How is TP rank mismatch handled? -4. What happens if seed and target have different TP sizes? - -**Recommendation**: Add tests for: -- TP rank matching logic -- Byte size validation -- Missing tensor handling -- ModelExpress server unavailability -- Timeout scenarios - -### 8. Documentation - -#### ⚠️ **Missing: Usage Documentation** - -**Recommendation**: Add documentation explaining: -1. How to set up ModelExpress server -2. How to run seed instance with `--model-express-source` -3. How to run target instance with `--remote-instance-weight-loader-backend=model_express` -4. Model name coordination requirements -5. TP rank matching requirements - -**Example**: -```markdown -## ModelExpress Remote Instance Loading - -### Setup - -1. Start ModelExpress server: - ```bash - modelexpress-server --port 8001 - ``` - -2. Start seed instance: - ```bash - python -m sglang.launch_server \ - --model-path meta-llama/Llama-3.1-70B \ - --model-express-url localhost:8001 \ - --model-express-model-name meta-llama/Llama-3.1-70B \ - --model-express-source \ - --remote-instance-weight-loader-start-seed-via-transfer-engine - ``` - -3. Start target instance: - ```bash - python -m sglang.launch_server \ - --model-path meta-llama/Llama-3.1-70B \ - --load-format remote_instance \ - --remote-instance-weight-loader-backend model_express \ - --model-express-url localhost:8001 \ - --model-express-model-name meta-llama/Llama-3.1-70B - ``` - -### Requirements - -- Seed and target must have **matching TP sizes** (e.g., both TP=8) -- Each target TP rank connects to the corresponding seed TP rank -- ModelExpress server must be accessible from both instances -- TransferEngine must be initialized on both instances -``` - -## Specific PR Review Comments - -### High Priority - -1. **Dtype Inference**: Fix dtype mapping to use actual tensor dtypes instead of element size (see Section 2) -2. **NIXL Backend Support**: Add error handling for NIXL backend case (see Section 1) -3. **Validation**: Make CLI argument validation stricter (see Section 4) -4. **Model Name Default**: Improve model name default logic (see Section 4) - -### Medium Priority - -5. **Tensor Name Matching**: Add more robust tensor name matching with better error messages (see Section 3) -6. **Timeout Configuration**: Make timeout configurable and visible in errors (see Section 5) -7. **Memory Registration Timing**: Document/ensure weights are stable before registration (see Section 6) -8. **Documentation**: Add usage documentation (see Section 8) - -### Low Priority - -9. **Test Coverage**: Add comprehensive tests (see Section 7) -10. **Logging**: Add more detailed logging for debugging -11. **Error Recovery**: Consider retry logic for transient ModelExpress errors - -## Alignment with ModelExpress PR 157 - -### βœ… **Correct Integration** - -The SGLang PR correctly uses the `oneof` pattern from ModelExpress PR 157: -- Extracts `transfer_engine_session_id` from `backend_metadata` oneof -- Uses `WhichOneof()` to check backend type -- Provides appropriate error handling - -### ⚠️ **Missing: Backend Selection** - -The SGLang PR assumes TransferEngine backend. It doesn't: -- Check if source uses NIXL backend -- Provide fallback to NIXL if TransferEngine unavailable -- Allow configuration of preferred backend - -**Recommendation**: Consider adding backend selection logic similar to what was discussed in ModelExpress PR 157 feedback. - -## Conclusion - -PR 19920 provides a solid integration of ModelExpress coordination for remote instance weight loading. The implementation correctly: - -1. βœ… Uses the `oneof` pattern from ModelExpress PR 157 -2. βœ… Implements TP rank matching -3. βœ… Handles byte-size matching for mixed-dtype models -4. βœ… Coordinates ready state via ModelExpress - -**Key Improvements Needed**: -1. Fix dtype inference to use actual tensor dtypes -2. Add NIXL backend error handling -3. Improve validation and error messages -4. Add comprehensive documentation and tests - -The PR is well-structured and follows SGLang's existing patterns. With the suggested improvements, it will provide a robust foundation for ModelExpress-coordinated weight loading. - ---- - -## PR Review Comments - -This section provides specific comments to make directly on PR 19920, organized by file and line numbers. These comments should be added as inline code review comments on the PR. - -### File: `python/sglang/srt/model_loader/loader.py` - -**Comment 1 - Line ~2340 (load_model_from_model_express, backend_field check)** -``` -⚠️ Backend Type Handling: Add support for NIXL backend error case - -Currently, the code only handles `transfer_engine_session_id` and raises a generic error for other backends. Consider adding explicit handling for NIXL: - -```python -backend_field = source_worker.WhichOneof("backend_metadata") -if backend_field == "transfer_engine_session_id": - seed_session_id = source_worker.transfer_engine_session_id -elif backend_field == "nixl_metadata": - raise RuntimeError( - f"ModelExpress: source worker {tp_rank} uses NIXL backend, " - f"but MODEL_EXPRESS backend requires TransferEngine. " - f"Please use a source with TransferEngine backend or use NIXL directly." - ) -else: - raise RuntimeError( - f"ModelExpress: unknown backend_metadata={backend_field} " - f"for worker {tp_rank}. Expected 'transfer_engine_session_id'." - ) -``` - -This provides clearer error messages when backend types don't match. -``` - -**Comment 2 - Line ~2350 (tensor descriptor conversion)** -``` -βœ… Good: Byte size matching approach - -The use of raw byte sizes (`td.size`) for matching is correct for RDMA transfers. RDMA is a memcpy operation, so byte-level matching is appropriate regardless of dtype differences (FP8 vs BF16, etc.). - -Consider adding a comment explaining this: -```python -# Convert tensor descriptors to {name: (addr, size_bytes)} format -# Use raw byte sizes -- RDMA is a memcpy, dtype matching is not required -# The model's quantization logic handles dtype conversions, not the transfer layer -seed_weight_info = {} -``` -``` - -**Comment 3 - Line ~2370 (tensor name matching)** -``` -⚠️ Error Message Enhancement: Improve missing tensor error - -When a tensor name is not found, provide more context: - -```python -for name, tensor in model.named_parameters(): - weight_info = seed_weight_info.get(name, None) - if weight_info is None: - # Provide helpful context - available_names = list(seed_weight_info.keys()) - logger.error( - f"ModelExpress: tensor '{name}' not found in seed metadata. " - f"Available tensors ({len(available_names)}): {available_names[:5]}..." - ) - raise RuntimeError( - f"ModelExpress: cannot find weight info for '{name}' " - f"in seed metadata. This may indicate a model architecture mismatch " - f"or different model versions between seed and target." - ) -``` - -This helps debug model architecture mismatches. -``` - -**Comment 4 - Line ~2280 (wait_for_ready call)** -``` -⚠️ Timeout Configuration: Make timeout configurable - -The `wait_for_ready` timeout is not visible in the code. Consider: - -```python -timeout_seconds = getattr(load_config, 'model_express_ready_timeout', 7200) # 2 hours default -ready, session_id, metadata_hash = mx_client.wait_for_ready( - model_name, worker_id=tp_rank, timeout_seconds=timeout_seconds, -) -if not ready: - raise RuntimeError( - f"ModelExpress: timed out waiting for seed ready " - f"(model={model_name}, worker={tp_rank}, timeout={timeout_seconds}s). " - f"Check that seed instance is running and has published ready flag." - ) -``` - -Also consider adding `model_express_ready_timeout` to LoadConfig and ServerArgs. -``` - -### File: `python/sglang/srt/model_executor/model_runner.py` - -**Comment 5 - Line ~700 (_publish_model_express_metadata, dtype inference)** -``` -πŸ”§ Critical: Fix dtype inference from element size - -The current mapping is lossy and can misidentify dtypes: - -```python -element_size_to_dtype = {1: "float8_e4m3fn", 2: "bfloat16", 4: "float32", 8: "float64"} -``` - -**Problem**: Multiple dtypes share the same element size: -- Size 2: `float16`, `bfloat16`, `int16`, `uint16` -- Size 1: `int8`, `uint8`, `float8_e4m3fn`, `float8_e5m2` - -**Solution**: Use actual tensor dtype: - -```python -tensors = [] -for name, (addr, numel, element_size) in weight_info.items(): - # Get actual tensor to determine dtype - param_dict = dict(self.model.named_parameters()) - if name not in param_dict: - logger.warning(f"Parameter {name} not found in model, using element_size inference") - dtype_str = element_size_to_dtype.get(element_size, "unknown") - else: - tensor = param_dict[name] - dtype_str = str(tensor.dtype).replace("torch.", "") - - tensors.append(p2p_pb2.TensorDescriptor( - name=name, - addr=addr, - size=numel * element_size, - device_id=self.gpu_id, - dtype=dtype_str, - )) -``` - -**Alternative**: Modify `register_memory_region` to return dtype as well: -```python -# In remote_instance_weight_loader_utils.py -weight_info[name] = (addr, numel, element_size, str(tensor.dtype).replace("torch.", "")) -``` -``` - -**Comment 6 - Line ~685 (model_name default)** -``` -⚠️ Model Name Default: Improve consistency - -Using `model_path` as default can lead to inconsistent model names: - -```python -model_name = ( - self.server_args.model_express_model_name - or self.server_args.model_path -) -``` - -**Issue**: `model_path` might be `/path/to/model` while target uses `meta-llama/Llama-3.1-70B`. - -**Recommendation**: -```python -model_name = self.server_args.model_express_model_name -if not model_name: - # Extract model name from model_path (last component) - import os - model_name = os.path.basename(self.server_args.model_path.rstrip('/')) - logger.warning( - f"ModelExpress: using model_name='{model_name}' from model_path. " - f"Consider setting --model-express-model-name explicitly for consistency." - ) -``` - -Or require explicit model name: -```python -if not self.server_args.model_express_model_name: - raise ValueError( - "--model-express-model-name is required when using --model-express-source" - ) -``` -``` - -**Comment 7 - Line ~1075 (model_specific_adjustment, memory registration timing)** -``` -⚠️ Memory Registration Timing: Ensure weights are stable - -The memory registration happens after model loading, but weights may be modified by `post_load_weights()`. Consider: - -```python -# In model_specific_adjustment(), before ModelExpress publish: -# Ensure model weights are finalized (post_load_weights may modify weights) -if hasattr(self.model, "post_load_weights"): - self.model.post_load_weights() - -# Now register memory regions (weights are stable) -if self.server_args.model_express_source: - if ( - self.remote_instance_transfer_engine_weight_info is None - and self.remote_instance_transfer_engine is not None - ): - self.remote_instance_transfer_engine_weight_info = ( - register_memory_region(self.model, self.remote_instance_transfer_engine) - ) - self._publish_model_express_metadata() -``` - -This ensures memory addresses remain valid after registration. -``` - -**Comment 8 - Line ~720 (publish_ready call)** -``` -πŸ“ Metadata Hash: Consider computing actual hash - -Currently, `metadata_hash` is set to empty string: - -```python -mx_client.publish_ready( - model_name, - worker_id=self.tp_rank, - session_id=mx_client.session_id, - metadata_hash="", # Empty hash -) -``` - -Consider computing an actual hash of the tensor descriptors for validation: - -```python -import hashlib -metadata_str = ",".join(sorted(f"{td.name}:{td.addr}:{td.size}" for td in tensors)) -metadata_hash = hashlib.md5(metadata_str.encode()).hexdigest() - -mx_client.publish_ready( - model_name, - worker_id=self.tp_rank, - session_id=mx_client.session_id, - metadata_hash=metadata_hash, -) -``` - -This enables target-side validation that metadata hasn't changed. -``` - -### File: `python/sglang/srt/server_args.py` - -**Comment 9 - Line ~2722 (validation logic)** -``` -⚠️ Validation: Make validation stricter - -The current validation silently falls back to `auto`: - -```python -if self.remote_instance_weight_loader_backend == "model_express": - if self.model_express_url is None: - logger.warning("Fallback load_format to 'auto'...") - self.load_format = "auto" -``` - -**Recommendation**: Raise an error instead: - -```python -if self.remote_instance_weight_loader_backend == "model_express": - if self.model_express_url is None: - raise ValueError( - "--model-express-url is required when using " - "--remote-instance-weight-loader-backend=model_express" - ) - if not self.validate_transfer_engine(): - raise ValueError( - "TransferEngine is required for model_express backend. " - "Please install mooncake.engine or use a different backend." - ) -``` - -Silent fallback can lead to confusion when users expect model_express backend. -``` - -**Comment 10 - Line ~5235 (CLI argument help text)** -``` -πŸ“ Documentation: Enhance help text - -The help text for `--model-express-source` could be more descriptive: - -```python -parser.add_argument( - "--model-express-source", - action="store_true", - help=( - "Run as a ModelExpress seed source: publish TransferEngine metadata " - "to the ModelExpress server after loading weights. " - "Requires --model-express-url and TransferEngine initialization. " - "Target instances can then load weights via --remote-instance-weight-loader-backend=model_express." - ), -) -``` - -This clarifies the relationship between source and target modes. -``` - -**Comment 11 - Line ~5783 (validate_transfer_engine, ModelExpress source check)** -``` -βœ… Good: TransferEngine validation includes ModelExpress source - -The validation correctly checks for ModelExpress source mode: - -```python -if self.model_express_source: - return True -``` - -This ensures TransferEngine is initialized when running as a seed source. -``` - -### File: `python/sglang/srt/configs/load_config.py` - -**Comment 12 - Line ~78-79 (LoadConfig fields)** -``` -βœ… Good: Clean addition of ModelExpress fields - -The addition of `model_express_url` and `model_express_model_name` to LoadConfig is clean and follows existing patterns. - -Consider adding a comment: -```python -# ModelExpress coordination fields (for remote_instance_weight_loader_backend=model_express) -model_express_url: Optional[str] = None -model_express_model_name: Optional[str] = None -``` -``` - -### Testing & Documentation - -**Comment 13 - Missing: Test Coverage** -``` -βœ… Test Coverage Needed - -Please add tests for: -1. **TP rank matching**: Verify each target rank connects to correct seed rank -2. **Byte size validation**: Test size mismatch detection -3. **Missing tensor handling**: Test behavior when tensor names don't match -4. **ModelExpress server unavailability**: Test error handling -5. **Timeout scenarios**: Test ready state timeout handling -6. **Mixed dtype models**: Test FP8 + BF16 models - -Example test structure: -```python -def test_model_express_tp_rank_matching(): - # Test that target TP rank 0 connects to seed TP rank 0 - ... - -def test_model_express_byte_size_validation(): - # Test that size mismatches are detected - ... -``` -``` - -**Comment 14 - Missing: Usage Documentation** -``` -πŸ“š Documentation Needed - -Please add documentation explaining: -1. How to set up ModelExpress server -2. How to run seed instance with `--model-express-source` -3. How to run target instance with `--remote-instance-weight-loader-backend=model_express` -4. Model name coordination requirements -5. TP rank matching requirements (seed and target must have same TP size) - -Consider adding to `docs/advanced_features/rfork.md` or creating a new section. -``` - -### Summary of Priority Comments - -**High Priority (Must Address)**: -- Comment 5: Fix dtype inference from element size (critical for correctness) -- Comment 9: Make validation stricter (prevents silent failures) -- Comment 6: Improve model name default logic (prevents coordination failures) - -**Medium Priority (Should Address)**: -- Comment 1: Add NIXL backend error handling -- Comment 3: Improve missing tensor error messages -- Comment 4: Make timeout configurable -- Comment 7: Ensure weights are stable before registration -- Comment 13: Add test coverage - -**Low Priority (Nice to Have)**: -- Comment 2: Add comment explaining byte size matching -- Comment 8: Compute actual metadata hash -- Comment 10: Enhance help text -- Comment 12: Add comments to LoadConfig -- Comment 14: Add usage documentation From 5d8ce71e23d13521ad083c6e1caeaa2c75ee7104 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Tue, 14 Apr 2026 13:17:13 -0700 Subject: [PATCH 10/25] docs: add MX RL integration overview (PRIME-RL + verl design) Made-with: Cursor Signed-off-by: Kavin Krishnan --- docs/MX_RL_OVERVIEW.md | 270 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 270 insertions(+) create mode 100644 docs/MX_RL_OVERVIEW.md diff --git a/docs/MX_RL_OVERVIEW.md b/docs/MX_RL_OVERVIEW.md new file mode 100644 index 00000000..938bfc05 --- /dev/null +++ b/docs/MX_RL_OVERVIEW.md @@ -0,0 +1,270 @@ +# ModelExpress for RL Post-Training β€” Design Overview + +**Last Updated**: April 2026 + +This document explains how ModelExpress (MX) accelerates reinforcement learning (RL) post-training by replacing slow weight transfer mechanisms with GPU-to-GPU RDMA. It covers the general design, the PRIME-RL proof of concept (working), and the verl integration plan. + +--- + +## The Problem: Weight Sync in RL + +RL post-training runs a continuous loop: generate text (inference) β†’ score it β†’ update the model (training) β†’ sync the updated weights back to inference β†’ repeat. The weight sync step is a bottleneck: + +| Method | How | Time (3 GB model) | Limitation | +|--------|-----|-------------------|-----------| +| Filesystem | Serialize to disk β†’ read back | 30-60s | Disk I/O, serialization | +| NCCL broadcast | Collective over network | 2-8s | Requires static groups, `max_async_level=1` | +| **MX + NIXL RDMA** | **GPU reads directly from GPU** | **<1s** | **None for async training** | + +ModelExpress eliminates the serialization-to-disk bottleneck while preserving async training capability. NCCL forces synchronous operation; the filesystem is slow. MX + NIXL gives both speed and flexibility. + +--- + +## How ModelExpress Works for RL + +### Architecture + +``` +Trainer GPU MX Server (gRPC + Redis) Inference GPU + β”‚ β”‚ β”‚ + β”‚ 1. optimizer.step() β”‚ β”‚ + β”‚ (weights updated in VRAM) β”‚ β”‚ + β”‚ β”‚ β”‚ + β”‚ 2. publish_weights() β”‚ β”‚ + │──── tensor addrs + NIXL ──────►│ β”‚ + β”‚ metadata via gRPC β”‚ β”‚ + β”‚ β”‚ 3. poll_for_source() β”‚ + β”‚ │◄──── "any new weights?" ───────────│ + β”‚ β”‚ β”‚ + β”‚ β”‚ 4. get_metadata() β”‚ + β”‚ │──── addrs + NIXL conn info ───────►│ + β”‚ β”‚ β”‚ + β”‚ 5. NIXL RDMA READ β”‚ β”‚ + │◄══════════════ GPU-to-GPU data transfer ═══════════════════════════►│ + β”‚ (inference GPU reads from trainer GPU, CPU not involved) β”‚ + β”‚ β”‚ β”‚ + β”‚ β”‚ 6. model.load_weights() β”‚ + β”‚ β”‚ (inference applies weights) β”‚ +``` + +**MX Server** stores only metadata β€” tensor names, GPU memory addresses, NIXL agent connection info, version tracking. It never touches weight data. The bulk transfer is a one-sided RDMA read between GPUs. + +### Client Library + +Two classes in `modelexpress_client/python/modelexpress/`: + +**`MxTrainingPublisher`** (trainer side): +```python +publisher = MxTrainingPublisher("trainer-rank-0", device_id=0, mx_server_url="mx-server:8001") +publisher.initialize(model_name="Qwen/Qwen2.5-1.5B") + +# After optimizer.step(): +publisher.publish_weights(model.state_dict(), step=training_step) +publisher.mark_ready() +``` + +- Registers GPU tensors with NIXL (once, on first call β€” addresses are stable across steps) +- Publishes tensor metadata + NIXL agent info to MX Server via gRPC +- Marks version as READY so inference can discover it + +**`MxRefitReceiver`** (inference side): +```python +receiver = MxRefitReceiver("inference-rank-0", device_id=0, mx_server_url="mx-server:8001") +receiver.initialize() + +source = receiver.poll_for_source(model_name="Qwen/Qwen2.5-1.5B") +for name, tensor in receiver.receive_weights_scratch(source): + ... # feed into model.load_weights() +``` + +- Queries MX Server for available weight sources +- Allocates scratch GPU buffers matching the source tensor layout +- NIXL RDMA reads weight data directly from the trainer's GPU +- Yields `(name, tensor)` pairs for the inference engine's weight loader + +### The Scratch-Buffer Approach + +RL trainers publish weights in HuggingFace format (339 separate tensors for Qwen2.5-1.5B). Inference engines like vLLM use fused tensors internally (198 parameters β€” Q/K/V merged into `qkv_proj`, gate/up merged into `gate_up_proj`). Names and shapes don't match. + +Solution: RDMA into temporary scratch buffers matching the trainer's layout, then feed through `model.load_weights()` which handles the name mapping and tensor fusion. The RDMA layer stays simple (just move bytes); the inference engine handles semantics. + +--- + +## PRIME-RL POC (Working) + +### What is PRIME-RL? + +An async-first RL framework by PrimeIntellect with three separate processes: +- **Trainer** β€” FSDP2 training (GPU pod) +- **Orchestrator** β€” Coordination, scoring (CPU pod) +- **Inference** β€” vLLM rollout generation (GPU pod) + +No Ray dependency. Raw Kubernetes pods. + +### Integration Summary + +| Item | Details | +|------|---------| +| **Repo** | `github.com/KavinKrishnan/prime-rl`, branch `kavink/mx-weight-broadcast` | +| **MX client branch** | `github.com/ai-dynamo/modelexpress`, branch `kavink/RL` | +| **New files in PRIME-RL** | `broadcast/modelexpress.py` (trainer), `worker/modelexpress.py` (inference), `Dockerfile.mx-arm64` | +| **Modified files** | 8 files, 93 lines added (configs, routes, factory, client helper) | +| **Cluster** | GKE DGXCloud, GB200 ARM64, 4 GPUs/node, RoCE networking | +| **Model** | Qwen/Qwen2.5-1.5B (3.55 GB BF16) | + +### How It Works + +1. Trainer runs `optimizer.step()`, gathers FSDP2 shards, calls `MxTrainingPublisher.publish_weights()` +2. Orchestrator detects new weights (via filesystem marker), tells inference to update +3. Inference calls `MxRefitReceiver.receive_weights_scratch()` β€” NIXL RDMA pulls weights from trainer GPU +4. Scratch tensors reshaped using safetensors header shapes, fed through `model.load_weights()` +5. vLLM resumes serving with updated model + +### Results + +| Metric | Filesystem | RDMA | Speedup | +|--------|-----------|------|---------| +| Weight update time | ~55s | <1s | **55x** | +| Transfer bandwidth | ~60 MB/s | 261-330 Gbps | ~500x | +| CPU involvement | Full | None | Eliminated | + +### Key Issues Resolved + +| Issue | Root Cause | Fix | +|-------|-----------|-----| +| `NIXL_ERR_NOT_ALLOWED` | Wrong `UCX_TLS` value | Match TRT-LLM: `self,sm,rc,cuda_copy,gdr_copy,tcp` | +| 800 KB metadata blob | Re-registering tensors every step | Register once, reuse cached metadata | +| `REMOTE_DISCONNECT` | UCX 1.18 + missing IMEX channels | UCX 1.20 + DRA `compute-domain-channel` claims | +| Tensor name mismatch | HF names vs vLLM fused names | Scratch-buffer approach + `model.load_weights()` | +| Shape assertion | 1D scratch vs 2D expected | Read shapes from safetensors header | + +### Cluster Config (GCP GB200) + +```yaml +hostNetwork: true +privileged: true +resourceClaims: + - name: compute-domain-channel + resourceClaimTemplateName: kavin-compute-domain-channel +env: + UCX_TLS: "self,sm,rc,cuda_copy,gdr_copy,tcp" + UCX_IB_GID_INDEX: "3" + OMPI_MCA_pml: "ob1" +volumes: + - /dev/infiniband (hostPath) +``` + +UCX 1.20+, IMEX channels via DRA, and `/dev/infiniband` are all required for cross-node GPU RDMA. + +--- + +## verl Integration (In Progress) + +### What is verl? + +A production-grade RL framework by ByteDance that uses Ray for orchestration. Supports FSDP, Megatron, vLLM, SGLang, TRT-LLM. Has a `CheckpointEngine` plugin system for weight transfer. + +### Why It's Easier Than PRIME-RL + +verl already has: +- **`CheckpointEngine` ABC** with `send_weights` / `receive_weights` β€” just implement a new backend +- **Existing NIXL engine** (`nixl_checkpoint_engine.py`) as a reference implementation +- **Bucketed transfers** that preserve tensor names and shapes β€” no scratch-buffer approach needed +- **`@CheckpointEngineRegistry.register("mx")`** β€” one decorator to plug in + +### Integration Summary + +| Item | Details | +|------|---------| +| **Repo** | `github.com/KavinKrishnan/verl`, branch `kavink/mx-checkpoint-engine` | +| **New file** | `verl/checkpoint_engine/mx_checkpoint_engine.py` (461 lines) | +| **Modified files** | 2 files, 8 lines (imports + config comment) | +| **Total** | 469 new lines, 2 modified | + +### `MxCheckpointEngine` Design + +Registered as `@CheckpointEngineRegistry.register("mx")`. Implements the `CheckpointEngine` ABC: + +- **`prepare()`** β€” Allocate send/recv GPU buckets, register with NIXL, create MX client +- **`build_topology()`** β€” Trainer rank 0 is the source, all rollout ranks connect via MX Server +- **`init_process_group()`** β€” Trainer adds rollout agents; rollouts add trainer agent +- **`send_weights()`** β€” Pack tensors into GPU buckets, send bucket metadata via ZMQ, make available for RDMA read +- **`receive_weights()`** β€” Receive metadata via ZMQ, NIXL RDMA read from trainer's bucket, yield `(name, tensor)` pairs +- **`finalize()`** β€” Cleanup connections, deregister memory + +Uses a **star topology** (trainer β†’ all rollouts) instead of the NIXL engine's ring. The MX Server enables future pipeline replication where rollouts become sources. + +### Config + +```yaml +actor_rollout_ref: + rollout: + checkpoint_engine: + backend: "mx" + engine_kwargs: + mx_server_url: "modelexpress-server:8001" + model_name: "Qwen/Qwen2.5-1.5B" +``` + +--- + +## Comparison: PRIME-RL vs verl Integration + +| Aspect | PRIME-RL | verl | +|--------|---------|------| +| Plugin system | Custom broadcast/worker extensions | `CheckpointEngine` registry | +| Existing NIXL support | None (built from scratch) | Full NIXL engine as reference | +| Lines of code | ~350 new + 93 modified | ~460 new + 8 modified | +| Ray dependency | None (raw K8s pods) | Ray actors, placement groups | +| Weight format | HF names β†’ scratch buffers β†’ `load_weights()` | Bucketed with shapes preserved | +| Tensor shape issue | Required safetensors header reading | Not an issue (bucket metadata carries shapes) | +| Colocated mode | N/A (always disaggregated) | `naive` engine handles colocated; MX for disaggregated | + +--- + +## Future: Pipeline Replication + +Current design uses a star topology (trainer β†’ all rollouts). At scale, the trainer's NIC becomes a bottleneck. [TensorHub](https://arxiv.org/abs/2604.09107v1) (ByteDance, April 2026) demonstrates **pipeline replication**: after a rollout receives weights, it publishes itself as a source. New rollouts pull from the nearest/least-loaded replica, creating a bandwidth-amplifying DAG. + +MX Server already supports multiple sources per model β€” implementing pipeline replication is a client-side change: +1. After RDMA receive, rollout calls `publish_weights()` on MX Server +2. New rollouts call `poll_for_source()` which returns the nearest available replica +3. MX Server load-balances across all replicas + +This is prioritized as P1 in our roadmap. + +--- + +## Repository Map + +### ModelExpress client (`kavink/RL` branch) + +``` +modelexpress_client/python/modelexpress/ +β”œβ”€β”€ training_publisher.py # MxTrainingPublisher β€” trainer-side publish +β”œβ”€β”€ refit_receiver.py # MxRefitReceiver β€” inference-side RDMA receive +β”œβ”€β”€ nixl_transfer.py # NixlTransferManager β€” NIXL agent lifecycle +β”œβ”€β”€ client.py # MxClient β€” gRPC client to MX Server +└── __init__.py # Exports MxTrainingPublisher, MxRefitReceiver +``` + +### PRIME-RL integration (`kavink/mx-weight-broadcast` branch) + +``` +src/prime_rl/ +β”œβ”€β”€ trainer/rl/broadcast/modelexpress.py # ModelExpressWeightBroadcast +β”œβ”€β”€ inference/vllm/worker/modelexpress.py # MxWeightUpdateWorker +β”œβ”€β”€ inference/vllm/server.py # /init_mx_broadcaster route (+13 lines) +β”œβ”€β”€ orchestrator/orchestrator.py # elif "modelexpress" branch (+7 lines) +β”œβ”€β”€ utils/client.py # init_mx_broadcast() (+34 lines) +└── configs/ # MxWeightBroadcastConfig (+32 lines) +``` + +### verl integration (`kavink/mx-checkpoint-engine` branch) + +``` +verl/ +β”œβ”€β”€ checkpoint_engine/mx_checkpoint_engine.py # MxCheckpointEngine +β”œβ”€β”€ checkpoint_engine/__init__.py # Optional import (+7 lines) +└── workers/config/rollout.py # "mx" in backend comment (+1 line) +``` From 5d4661015eb75084f81e8b45d625fbabc6e44326 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Tue, 14 Apr 2026 13:43:49 -0700 Subject: [PATCH 11/25] docs: add component architecture diagrams for PRIME-RL and verl POCs Made-with: Cursor Signed-off-by: Kavin Krishnan --- docs/MX_RL_OVERVIEW.md | 123 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 123 insertions(+) diff --git a/docs/MX_RL_OVERVIEW.md b/docs/MX_RL_OVERVIEW.md index 938bfc05..be62e7fe 100644 --- a/docs/MX_RL_OVERVIEW.md +++ b/docs/MX_RL_OVERVIEW.md @@ -101,6 +101,63 @@ An async-first RL framework by PrimeIntellect with three separate processes: No Ray dependency. Raw Kubernetes pods. +### Architecture + +```mermaid +graph LR + subgraph cluster["GKE GB200 Cluster"] + direction TB + + subgraph control["Control Plane Β· CPU"] + direction LR + orch["Orchestrator"] + mx["MX Server + Redis"] + end + + subgraph gpus["GPU Pods Β· 4x GB200 each"] + direction LR + + subgraph tp["Trainer Pod"] + direction TB + fsdp["FSDP2 Training"] + pub["MX Publisher"] + nt(["NIXL Agent"]) + fsdp --> pub --> nt + end + + subgraph ip["Inference Pod"] + direction TB + vllm["vLLM Server"] + recv["MX Receiver"] + ni(["NIXL Agent"]) + ni --> recv --> vllm + end + end + + orch -. "HTTP rollouts" .-> vllm + orch -. "HTTP update_weights" .-> recv + pub -- "gRPC publish" --> mx + recv -- "gRPC discover" --> mx + nt <== "RDMA RoCE Β· 261-330 Gbps" ==> ni + end + + style cluster fill:#1a1a2e,stroke:#16213e,color:#e0e0e0 + style control fill:#1a1a2e,stroke:#533483,color:#e0e0e0 + style gpus fill:#0f3460,stroke:#533483,color:#e0e0e0 + style tp fill:#162447,stroke:#533483,color:#e0e0e0 + style ip fill:#162447,stroke:#533483,color:#e0e0e0 + style fsdp fill:#533483,stroke:#e94560,color:#fff + style vllm fill:#533483,stroke:#e94560,color:#fff + style orch fill:#533483,stroke:#e94560,color:#fff + style pub fill:#1b5e20,stroke:#4caf50,color:#fff + style recv fill:#1b5e20,stroke:#4caf50,color:#fff + style mx fill:#1b5e20,stroke:#4caf50,color:#fff + style nt fill:#2e7d32,stroke:#66bb6a,color:#fff + style ni fill:#2e7d32,stroke:#66bb6a,color:#fff +``` + +Green = ModelExpress / NIXL components. Purple = existing framework components. + ### Integration Summary | Item | Details | @@ -164,6 +221,72 @@ UCX 1.20+, IMEX channels via DRA, and `/dev/infiniband` are all required for cro A production-grade RL framework by ByteDance that uses Ray for orchestration. Supports FSDP, Megatron, vLLM, SGLang, TRT-LLM. Has a `CheckpointEngine` plugin system for weight transfer. +### Architecture + +```mermaid +graph LR + subgraph cluster["Ray Cluster Β· GKE GB200"] + direction TB + + subgraph driver["Driver Β· CPU"] + task["TaskRunner"] + mgr["CheckpointEngine Manager"] + end + + subgraph trainer_wg["Trainer WorkerGroup Β· GPU"] + direction LR + tw0["Worker 0
FSDP2 + MX CE"] + tw1["Worker 1
FSDP2 + CE"] + tw2["Worker 2
FSDP2 + CE"] + tw3["Worker 3
FSDP2 + CE"] + end + + subgraph rollout_wg["Rollout Replicas Β· GPU"] + direction LR + rw0["CE Worker 0"] + rw1["CE Worker 1"] + rw2["CE Worker 2"] + rw3["CE Worker 3"] + vs["vLLM Server"] + end + + subgraph mx_svc["MX Server + Redis Β· CPU"] + mx["gRPC Metadata Broker"] + end + + task --> mgr + mgr -. "ray.get" .-> tw0 + mgr -. "ray.get" .-> rw0 + tw0 -- "gRPC publish" --> mx + rw0 -- "gRPC discover" --> mx + tw0 <== "NIXL RDMA" ==> rw0 + tw1 <== "NIXL RDMA" ==> rw1 + tw2 <== "NIXL RDMA" ==> rw2 + tw3 <== "NIXL RDMA" ==> rw3 + rw0 -. "CUDA IPC" .-> vs + end + + style cluster fill:#1a1a2e,stroke:#16213e,color:#e0e0e0 + style driver fill:#1a1a2e,stroke:#533483,color:#e0e0e0 + style trainer_wg fill:#0f3460,stroke:#533483,color:#e0e0e0 + style rollout_wg fill:#0f3460,stroke:#533483,color:#e0e0e0 + style mx_svc fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0 + style task fill:#533483,stroke:#e94560,color:#fff + style mgr fill:#1b5e20,stroke:#4caf50,color:#fff + style tw0 fill:#1b5e20,stroke:#4caf50,color:#fff + style tw1 fill:#162447,stroke:#533483,color:#e0e0e0 + style tw2 fill:#162447,stroke:#533483,color:#e0e0e0 + style tw3 fill:#162447,stroke:#533483,color:#e0e0e0 + style rw0 fill:#1b5e20,stroke:#4caf50,color:#fff + style rw1 fill:#2e7d32,stroke:#66bb6a,color:#fff + style rw2 fill:#2e7d32,stroke:#66bb6a,color:#fff + style rw3 fill:#2e7d32,stroke:#66bb6a,color:#fff + style vs fill:#533483,stroke:#e94560,color:#fff + style mx fill:#1b5e20,stroke:#4caf50,color:#fff +``` + +Green = MX checkpoint engine components. Purple = existing verl/Ray components. The `CheckpointEngineManager` coordinates the MX CE workers on both trainer and rollout sides. Each trainer-rollout rank pair transfers via NIXL RDMA. Received weights reach vLLM via CUDA IPC through the existing `ServerAdapter`. + ### Why It's Easier Than PRIME-RL verl already has: From b30bf8f86c323533ddc13a13f30c52de3ce1e141 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Wed, 22 Apr 2026 11:51:55 -0700 Subject: [PATCH 12/25] =?UTF-8?q?docs:=20add=20verl=20=C3=97=20ModelExpres?= =?UTF-8?q?s=20integration=20overview=20with=20vertical=20diagrams?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Covers the MxCheckpointEngine design, Ray actor topology, and GB200 prototype results (10 steps, avg ~1.25s cross-node RDMA weight sync). Made-with: Cursor Signed-off-by: Kavin Krishnan --- docs/RL/VERL_MX_OVERVIEW.md | 491 ++++++++++++++++++++++++++++++++++++ 1 file changed, 491 insertions(+) create mode 100644 docs/RL/VERL_MX_OVERVIEW.md diff --git a/docs/RL/VERL_MX_OVERVIEW.md b/docs/RL/VERL_MX_OVERVIEW.md new file mode 100644 index 00000000..e7df334d --- /dev/null +++ b/docs/RL/VERL_MX_OVERVIEW.md @@ -0,0 +1,491 @@ +# ModelExpress Γ— verl β€” Design Overview + +**Last Updated**: April 2026 +**Status**: E2E working β€” cross-node RDMA weight transfers via `MxCheckpointEngine` on 2Γ— GB200 nodes (GKE). + +This document covers how ModelExpress (MX) plugs into [verl](https://github.com/volcengine/verl) for RL post-training weight synchronization. It walks through the component design, the Ray actor integration, the `CheckpointEngine` surface, and the GB200 prototype results. + +--- + +## 1. Design Overview + +verl is a Ray-orchestrated RL framework. Its `CheckpointEngine` plugin system is the seam where MX slots in. `MxCheckpointEngine` replaces the default `naive` sync (process-local copy) or the built-in `nixl` ring engine with a **star topology over RDMA**, coordinated by the MX Server. + +### What MX adds to verl + +| Layer | Role | Implementation | +|-------|------|----------------| +| Metadata plane | Source discovery, version tracking, topology coordination | MX Server (gRPC) + Redis | +| Data plane | GPU-to-GPU tensor transport | NIXL (UCX / `rc_mlx5` / RoCE) | +| verl integration | `CheckpointEngine` ABC implementation | `verl/checkpoint_engine/mx_checkpoint_engine.py` (461 lines) | +| Transport choreography | Bucket metadata handshake per transfer | ZMQ PUSH/PULL | + +### Component diagram (vertical, document-friendly) + +```mermaid +graph TB + subgraph driver["Driver Β· Ray head Β· CPU"] + task["TaskRunner
(Ray actor)"] + mgr["CheckpointEngineManager"] + task --> mgr + end + + subgraph mx_meta["Metadata Plane Β· CPU"] + mx["MX Server
(gRPC)"] + redis[("Redis")] + mx --> redis + end + + subgraph trainer["Trainer node Β· 4Γ— GB200"] + direction TB + tw["WorkerDict Γ— 4
FSDP2 + optimizer"] + tce["MxCheckpointEngine
(trainer role)"] + tnixl(["NIXL Agent Γ— 4"]) + tw --> tce --> tnixl + end + + subgraph rollout["Rollout node Β· 4Γ— GB200"] + direction TB + cew["CheckpointEngineWorker Γ— 4
(standalone replicas)"] + rce["MxCheckpointEngine
(rollout role)"] + rnixl(["NIXL Agent Γ— 4"]) + vllm["vLLM Server Γ— 4
(ServerAdapter)"] + cew --> rce --> rnixl + rce -. "load_weights" .-> vllm + end + + mgr -- "ray.get(prepare/send)" --> tw + mgr -- "ray.get(prepare/recv)" --> cew + tce -- "gRPC publish
(agent_meta, bucket)" --> mx + rce -- "gRPC discover
(poll_for_source)" --> mx + tce -. "ZMQ bucket metadata
(name, shape, dtype, offset)" .-> rce + tnixl <== "NIXL RDMA READ
RoCE Β· rc_mlx5" ==> rnixl + + style driver fill:#1a1a2e,stroke:#533483,color:#e0e0e0 + style mx_meta fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0 + style trainer fill:#0f3460,stroke:#533483,color:#e0e0e0 + style rollout fill:#0f3460,stroke:#533483,color:#e0e0e0 + style task fill:#533483,stroke:#e94560,color:#fff + style mgr fill:#533483,stroke:#e94560,color:#fff + style tw fill:#533483,stroke:#e94560,color:#fff + style cew fill:#2e7d32,stroke:#66bb6a,color:#fff + style vllm fill:#533483,stroke:#e94560,color:#fff + style tce fill:#1b5e20,stroke:#4caf50,color:#fff + style rce fill:#1b5e20,stroke:#4caf50,color:#fff + style tnixl fill:#2e7d32,stroke:#66bb6a,color:#fff + style rnixl fill:#2e7d32,stroke:#66bb6a,color:#fff + style mx fill:#1b5e20,stroke:#4caf50,color:#fff + style redis fill:#162447,stroke:#533483,color:#e0e0e0 +``` + +**Legend**: Green boxes are MX/NIXL additions. Purple boxes are existing verl/Ray/vLLM components. The diagram is rendered top-to-bottom (`graph TB`) so it fits in a single document column without horizontal scrolling. + +### Key ideas + +- **MX Server stores metadata only** β€” tensor names, GPU memory addresses, NIXL agent blobs, version numbers. It never touches weight bytes. +- **The heavy transfer is a one-sided RDMA READ** from the rollout's NIXL agent into the trainer's GPU memory, going GPU-direct over RoCE via `rc_mlx5`. +- **Star topology, not a ring.** verl's built-in NIXL engine uses a ring; MX uses the server as a central rendezvous, which is simpler to reason about and sets up future pipeline replication (rollouts can become secondary sources). +- **Bucketed transfer preserves shapes.** Unlike the PRIME-RL POC (which needs scratch buffers), verl's `CheckpointEngine` passes a tensor generator with names and shapes. MX packs them into GPU buckets and the receiver pulls them out by offset β€” no reshape tricks required. + +--- + +## 2. Timing Diagram β€” One `update_weights` Step + +```mermaid +sequenceDiagram + participant D as Driver
(CheckpointEngineManager) + participant T as Trainer WorkerDict + participant CE_T as MxCheckpointEngine
(trainer) + participant MX as MX Server + participant CE_R as MxCheckpointEngine
(rollout) + participant R as CheckpointEngineWorker
(rollout) + participant V as vLLM ServerAdapter + + Note over T: optimizer.step() complete + + D->>T: ray.remote: update_weights(step=N) + D->>R: ray.remote: update_weights(step=N) + + par prepare() on both sides + T->>CE_T: prepare() + CE_T->>CE_T: allocate GPU send bucket
register with NIXL + CE_T->>MX: publish agent_meta + tensor layout + R->>CE_R: prepare() + CE_R->>CE_R: allocate GPU recv bucket
register with NIXL + CE_R->>MX: poll_for_source(model_name) + MX-->>CE_R: trainer agent_meta + end + + D->>D: build_topology()
rank 0 β†’ all rollouts + + par init_process_group + CE_T->>CE_T: add rollout agents (NIXL) + CE_R->>CE_R: add trainer agent (NIXL) + CE_T-->>CE_R: ZMQ handshake (ip:port) + end + + loop per bucket (streamed) + T->>CE_T: send_weights(yield name, tensor) + CE_T->>CE_T: pack tensor into GPU bucket + CE_T->>CE_R: ZMQ: bucket desc (name, shape, dtype, offset) + CE_R->>CE_T: NIXL RDMA READ
(GPU→GPU, RoCE) + CE_R->>R: yield (name, tensor_view) + R->>V: load_weights([(name, tensor)]) + end + + par finalize + T->>CE_T: finalize() — deregister buffers + R->>CE_R: finalize() — deregister buffers + end + + Note over V: vLLM resumes with updated weights +``` + +**Observed per-step timing** (GB200, Qwen2.5-1.5B BF16, cross-node RoCE): + +| Phase | Wall time | +|-------|-----------| +| `prepare` + `build_topology` + `init_process_group` | ~0.3-0.4s | +| `send_weights` / `receive_weights` (RDMA) | ~0.6-0.8s | +| `finalize` | ~0.1s | +| **Total `update_weights`** | **~1.25s avg** (range 1.22-1.28s) | + +For the same model and cluster, the default `naive` engine averages **~1.6s** (in-process copy). The MX engine is faster *and* does a real cross-node transfer — the naive baseline only works because hybrid mode colocates trainer and rollout on the same GPUs. + +--- + +## 3. ModelExpress and the Ray Actor Design + +verl's runtime is a web of Ray actors. Understanding where `MxCheckpointEngine` lives inside that web is the key to the integration. + +### Actor topology + +```mermaid +graph TB + subgraph head["Ray Head · Node 1 · CPU driver"] + direction TB + tr["TaskRunner
(@ray.remote)"] + ceman["CheckpointEngineManager
(driver-side orchestrator)"] + tr --> ceman + end + + subgraph trainer_pg["Trainer Placement Group Β· Node 1 Β· 4 GPUs"] + direction TB + wd0["WorkerDict 0
ActorRolloutRefWorker
FSDP2 engine"] + wd1["WorkerDict 1"] + wd2["WorkerDict 2"] + wd3["WorkerDict 3"] + mxt["MxCheckpointEngine
instance (per worker)"] + wd0 --> mxt + end + + subgraph rollout_pg["Rollout Placement Group Β· Node 2 Β· 4 GPUs"] + direction TB + cew0["CheckpointEngineWorker 0
(@ray.remote)"] + cew1["CheckpointEngineWorker 1"] + cew2["CheckpointEngineWorker 2"] + cew3["CheckpointEngineWorker 3"] + mxr["MxCheckpointEngine
instance (per worker)"] + vllm["vLLM Server actor Γ— 4
(ServerAdapter)"] + cew0 --> mxr + mxr --> vllm + end + + subgraph meta["MX Server Actor Group Β· CPU"] + mx["MX gRPC Server"] + rd[("Redis")] + mx --> rd + end + + ceman -- "ray.get
execute_checkpoint_engine(send)" --> wd0 + ceman -- "ray.get
execute_checkpoint_engine(recv)" --> cew0 + mxt -- "gRPC" --> mx + mxr -- "gRPC" --> mx + mxt <== "NIXL RDMA
(rank-paired)" ==> mxr + + style head fill:#1a1a2e,stroke:#533483,color:#e0e0e0 + style trainer_pg fill:#0f3460,stroke:#533483,color:#e0e0e0 + style rollout_pg fill:#0f3460,stroke:#533483,color:#e0e0e0 + style meta fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0 + style tr fill:#533483,stroke:#e94560,color:#fff + style ceman fill:#533483,stroke:#e94560,color:#fff + style wd0 fill:#533483,stroke:#e94560,color:#fff + style wd1 fill:#162447,stroke:#533483,color:#e0e0e0 + style wd2 fill:#162447,stroke:#533483,color:#e0e0e0 + style wd3 fill:#162447,stroke:#533483,color:#e0e0e0 + style cew0 fill:#2e7d32,stroke:#66bb6a,color:#fff + style cew1 fill:#2e7d32,stroke:#66bb6a,color:#fff + style cew2 fill:#2e7d32,stroke:#66bb6a,color:#fff + style cew3 fill:#2e7d32,stroke:#66bb6a,color:#fff + style vllm fill:#533483,stroke:#e94560,color:#fff + style mxt fill:#1b5e20,stroke:#4caf50,color:#fff + style mxr fill:#1b5e20,stroke:#4caf50,color:#fff + style mx fill:#1b5e20,stroke:#4caf50,color:#fff + style rd fill:#162447,stroke:#533483,color:#e0e0e0 +``` + +### Three actor classes that matter + +1. **`TaskRunner`** β€” a single CPU Ray actor that owns the training loop. It holds the `CheckpointEngineManager` and drives PPO/GRPO iteration. +2. **`WorkerDict` (trainer side)** β€” a Ray GPU actor per trainer rank. Hosts the FSDP2 model, optimizer, and β€” under our integration β€” an `MxCheckpointEngine` instance for the trainer role. +3. **`CheckpointEngineWorker` (rollout side)** β€” a dedicated Ray GPU actor per rollout rank. It exists only in **standalone mode** (rollout on its own GPU pool). It hosts the `MxCheckpointEngine` in the rollout role and drives `ServerAdapter.load_weights` into the colocated vLLM engine. + +### Why standalone mode matters + +verl has two deployment modes for the rollout: + +| Mode | Ray actors | Status for MX | +|------|-----------|--------------| +| **Hybrid (colocated)** | `WorkerDict` does both training and rollout | ❌ No `execute_checkpoint_engine` method β€” `CheckpointEngineManager` fails | +| **Standalone (disaggregated)** | Trainer uses `WorkerDict`, rollout uses `CheckpointEngineWorker` | βœ… Full CE lifecycle available | + +This is a verl framework constraint, not an MX constraint β€” the built-in `nixl` and `nccl` engines have the same requirement. Our prototype runs in standalone mode on 2 nodes. + +### How a weight sync crosses the actor boundary + +``` +TaskRunner (Node 1) + └─► CheckpointEngineManager.update_weights(step=N) # driver-side + β”œβ”€β–Ί ray.get([wd0.execute_checkpoint_engine("prepare"), # fan-out to trainer + β”‚ wd1.execute_checkpoint_engine("prepare"), + β”‚ wd2.execute_checkpoint_engine("prepare"), + β”‚ wd3.execute_checkpoint_engine("prepare")]) + β”œβ”€β–Ί ray.get([cew0.execute_checkpoint_engine("prepare"), # fan-out to rollout + β”‚ cew1...cew3.execute_checkpoint_engine("prepare")]) + β”œβ”€β–Ί build_topology(agent_meta_list) # computed on driver + β”œβ”€β–Ί ray.get([.init_process_group(topology) on both sides]) + β”œβ”€β–Ί ray.get([wd0..3.send_weights(generator), cew0..3.receive_weights()]) # the RDMA moment + └─► ray.get([.finalize() on both sides]) +``` + +The `execute_checkpoint_engine("method", *args)` pattern is how the manager dispatches a named lifecycle call onto every CE-hosting actor in parallel. Every `ray.get` is a fan-out over 4 trainer + 4 rollout actors, so all 8 ranks move in lock-step. + +### Where `MxCheckpointEngine` is instantiated + +- On the **trainer**: `ActorRolloutRefWorker.init_model()` creates the engine when the config sets `actor_rollout_ref.rollout.checkpoint_engine.backend=mx`. The engine registers to handle `send_weights`. +- On the **rollout**: `CheckpointEngineWorker.__init__` constructs the same class (via the `CheckpointEngineRegistry`) but with `role="rollout"`. It registers to handle `receive_weights` and to drive `load_weights` into the ServerAdapter. + +One class, two roles, distinguished by which actor type instantiates it. Both sides talk to the same MX Server over gRPC. + +--- + +## 4. ModelExpress and the Checkpoint Engine + +verl's `CheckpointEngine` ABC is the small, well-defined plugin surface that makes MX a drop-in. + +### The ABC + +```python +class CheckpointEngine(ABC): + def prepare(self) -> dict: ... # allocate buffers, NIXL register + @classmethod + def build_topology(cls, trainer_meta, rollout_meta) -> tuple: ... + def init_process_group(self, topology, rank) -> None: ... + async def send_weights(self, weights_iter) -> None: ... # trainer + async def receive_weights(self) -> AsyncGenerator: ... # rollout + def finalize(self) -> None: ... +``` + +All six methods are implemented in `verl/checkpoint_engine/mx_checkpoint_engine.py` (461 lines). The class is registered with one decorator: + +```python +@CheckpointEngineRegistry.register("mx") +class MxCheckpointEngine(CheckpointEngine): + ... +``` + +…which makes `backend: "mx"` selectable from Hydra config. + +### Lifecycle responsibilities + +```mermaid +graph LR + subgraph life["MxCheckpointEngine lifecycle (one update_weights step)"] + direction LR + P["prepare()
alloc GPU bucket
NIXL register
MX publish/discover"] + B["build_topology()
star: rank 0 β†’ all rollouts"] + I["init_process_group()
NIXL add_remote_agent
ZMQ handshake"] + S["send_weights()
pack bucket
push ZMQ desc
wait for RDMA"] + R["receive_weights()
pull ZMQ desc
NIXL RDMA READ
yield tensors"] + F["finalize()
dereg NIXL
close ZMQ"] + P --> B --> I + I --> S + I --> R + S --> F + R --> F + end + + style life fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0 + style P fill:#1b5e20,stroke:#4caf50,color:#fff + style B fill:#1b5e20,stroke:#4caf50,color:#fff + style I fill:#1b5e20,stroke:#4caf50,color:#fff + style S fill:#1b5e20,stroke:#4caf50,color:#fff + style R fill:#1b5e20,stroke:#4caf50,color:#fff + style F fill:#1b5e20,stroke:#4caf50,color:#fff +``` + +### Method-by-method behavior + +| Method | Trainer side | Rollout side | +|--------|--------------|--------------| +| `prepare()` | Allocate pinned GPU send bucket, register with NIXL agent, publish agent metadata + tensor layout to MX Server via gRPC | Allocate GPU recv bucket, register with NIXL, call `MxClient.poll_for_source(model_name)` to get the trainer's agent blob | +| `build_topology()` | Driver-side utility. Produces `(trainer_agent β†’ [rollout_agents])` star mapping | Same (called on driver) | +| `init_process_group()` | For each rollout rank: `nixl_agent.add_remote_agent(rollout_meta)` | `nixl_agent.add_remote_agent(trainer_meta)`; ZMQ PULL socket bound on free port, advertised to trainer | +| `send_weights(iter)` | Consume `(name, tensor)` generator. Pack tensors into the GPU bucket at known offsets. Send `BucketDesc{name, shape, dtype, offset, nbytes}` over ZMQ PUSH. Block until rollout signals ACK | β€” | +| `receive_weights()` | β€” | Pull `BucketDesc` over ZMQ. Issue NIXL RDMA READ into the recv bucket at the given offset. Yield `(name, tensor_view)` to verl, which forwards to `ServerAdapter.load_weights` | +| `finalize()` | Deregister NIXL memory regions, close ZMQ, tell MX Server to retire this version | Same | + +### Why a bucket, not per-tensor RDMA + +Each NIXL RDMA transfer has a fixed latency overhead (~50-100Β΅s). A 1.5B model has hundreds of parameters. Issuing per-tensor transfers would spend more time in transfer setup than in data movement. The engine packs tensors into a contiguous GPU bucket (up to a configured size, e.g. 256 MB) and issues one RDMA READ per bucket. The ZMQ channel carries a list of `BucketDesc` entries so the receiver can slice the bucket back into named tensors. + +This is the same pattern used by verl's built-in NIXL engine. MX adopts it for parity β€” and because it works. + +### Config + +```yaml +actor_rollout_ref: + rollout: + checkpoint_engine: + backend: mx + engine_kwargs: + mx_server_url: modelexpress-server.kavin.svc.cluster.local:8001 + model_name: Qwen/Qwen2.5-1.5B + bucket_size_mb: 256 # optional, default 256 + skip_sleep_wake: true # avoid vLLM multiproc sleep/wake crash on ARM64 +``` + +### What differs from PRIME-RL + +| Concern | PRIME-RL | verl / MxCheckpointEngine | +|---------|----------|---------------------------| +| Plugin point | Custom `WeightBroadcast` ABC + vLLM worker extension | `CheckpointEngine` ABC (native, already has NIXL and NCCL siblings) | +| Shape handling | Scratch buffers + safetensors header reshape | Bucket carries `(name, shape, dtype)` β€” no reshape needed | +| Fused params (Q/K/V, gate/up) | Rely on `model.load_weights()` to fuse from HF names | Trainer publishes already-in-target-format buckets; rollout passes through to `load_weights` | +| Allgather on trainer | Rank 0 gathers FSDP shards before publish | FSDP shards are packed per-rank; star topology fans out to rollout ranks | + +--- + +## 5. Prototype on GB200 β€” Results + +### Cluster + +| Resource | Value | +|----------|-------| +| Platform | GKE DGXCloud, GB200 ARM64 | +| Nodes | 2 (trainer + rollout), `hostNetwork=true` | +| GPUs | 8 Γ— GB200 (4 per node) | +| Node pools | `customer-gpu-w0e` (trainer), `customer-gpu-o7v` (rollout) | +| Fabric | RoCE v2, IMEX channels via DRA `compute-domain-channel` | +| UCX | v1.20.0 built from source, transports `self,sm,rc,cuda_copy,gdr_copy,tcp` | +| NIXL | 1.1.0 (main branch) | +| PyTorch | 2.6 + cu128 | +| vLLM | 0.18.1 (0.19.0 has an ARM64 multiproc `resource_tracker` bug) | +| Image | `nvcr.io/nvidian/dynamo-dev/verl-mx:latest` | + +### Deployment + +``` +Node 1 (gke-...-w0e-...-tz1d, IP 10.0.0.83) +β”œβ”€ Ray head StatefulSet (verl-mx-head-0) +β”‚ β”œβ”€ TaskRunner / CheckpointEngineManager +β”‚ └─ 4Γ— WorkerDict (FSDP2 trainers) + 4Γ— MxCheckpointEngine +└─ 4Γ— NIXL agent (rc_mlx5) + +Node 2 (gke-...-o7v-...-mflg, IP 10.0.15.225) +β”œβ”€ Ray worker StatefulSet (verl-mx-worker-0) +β”‚ β”œβ”€ 4Γ— CheckpointEngineWorker + 4Γ— MxCheckpointEngine +β”‚ └─ 4Γ— vLLM ServerAdapter +└─ 4Γ— NIXL agent (rc_mlx5) + +MX Server + Redis (kavin namespace, reachable from both nodes over gRPC) +``` + +### What we observed + +- `[MX-DEBUG] Initializing 4 replicas in STANDALONE mode (worker_group=None)` β€” rollout running as dedicated CE workers, not fused WorkerDicts. +- `[MX-DEBUG] Standalone replicas ([STANDALONE x4]), using mx checkpoint engine` β€” repeated 11 times across the run (one per `update_weights`). +- `Backend UCX was instantiated` + `Initialized NIXL agent: ` on both nodes. +- UCX `rc_mlx5` transport negotiated β€” confirmed RoCE data path. +- Full lifecycle traced per step: `prepare β†’ build_topology β†’ init_process_group β†’ send_weights / receive_weights β†’ finalize`. + +### Per-step `update_weights` timing (MX engine, cross-node RDMA) + +| Step | `update_weights` (s) | +|------|---------------------| +| 1 | 1.278 | +| 2 | 1.250 | +| 3 | 1.233 | +| 4 | 1.252 | +| 5 | 1.243 | +| 6 | 1.223 | +| 7 | 1.235 | +| 8 | 1.249 | +| 9 | 1.263 | +| 10 | 1.282 | +| **Avg** | **β‰ˆ 1.25s** | + +### Headline metrics + +| Metric | Value | +|--------|-------| +| Model | Qwen/Qwen2.5-1.5B (BF16, β‰ˆ 3 GB resident) | +| Steps completed | 10 (full PPO/GRPO run) | +| Avg step time | ~8.1-8.8s | +| Avg `update_weights` (MX) | **~1.25s** | +| Avg `update_weights` (naive baseline, hybrid mode) | ~1.6s | +| Throughput | 135-163 tokens/sec | +| Transport | NIXL / UCX `rc_mlx5` (RoCE RDMA) | +| Data path | Cross-node GPUβ†’GPU (no CPU staging, no filesystem) | + +### Manual NIXL transfer test (isolated, inside one pod) + +Before wiring the engine into the training loop, we ran a standalone `MxTrainingPublisher` β†’ `MxRefitReceiver` test to validate the MX/NIXL data plane by itself: + +| Metric | Value | +|--------|-------| +| Payload | 10 tensors Γ— 2 MB = 21 MB | +| Transfer time | 0.16s | +| Data integrity | All tensors byte-verified correct | +| Environment | Loopback (same GPU, single pod) | + +This proved publisher/receiver correctness before moving to cross-node. + +### How we got there β€” build history + +21 Docker image iterations were needed to reach a working ARM64 build. Dominant issues and their fixes: + +| Category | Resolution | +|----------|-----------| +| NIXL build (missing tag, `pybind11`) | Clone NIXL `main`, add `pybind11-dev` to apt install | +| PyTorch CPU-only on ARM64 | Install `torch==2.6` with `--index-url` `download.pytorch.org/whl/cu128` *before* vLLM; reinstall matching version after | +| `flash_attn` absent on ARM64 | Ship `flash_attn_compat.py` in `verl/utils` β€” real SDPA fallbacks for GQA attention and cross-entropy | +| vLLM 0.19 multiproc crash | Downgrade to 0.18.1 (stable on ARM64) | +| Triton JIT β€” missing gcc / nvcc | Base runtime on `cuda:12.8.1-devel`, add `gcc/g++` | +| `cupy` not installed | Made optional in `MxCheckpointEngine`, added torch fallback for bucket allocation | +| Ray worker β†’ head GCS | Use head's FQDN `verl-mx-head-0.verl-mx-head.kavin.svc.cluster.local:6379` | +| `CheckpointEngineManager` fails in hybrid mode | Deploy in standalone (2-node) mode β€” matches built-in NIXL/NCCL engine requirements | + +### What this proves + +1. **MX works on Ray.** The `CheckpointEngine` plugin surface is sufficient to express a star-topology RDMA transfer with server-mediated discovery. +2. **Cross-node RoCE RDMA is real.** `update_weights` at 1.25s for a 3 GB model on a 2-node GB200 cluster is consistent with UCX `rc_mlx5` over RoCE and beats the in-process naive baseline even before we tune the bucket size. +3. **The ARM64 path is painful but survivable.** All the image work is in `docker/Dockerfile.mx-arm64` and the compat shim, and is shared with (and borrowed from) the PRIME-RL POC. +4. **Standalone rollout is the production shape.** Hybrid/colocated mode is useful for debugging but cannot drive any non-naive checkpoint engine β€” true for NIXL, NCCL, and MX alike. + +--- + +## 6. Related Documents + +- **PRIME-RL POC**: `recovery/reinforcement learning/PRIMERL_MX_NIXL_Overview.md` +- **verl POC state log**: `recovery/reinforcement learning/VERL_POC_STATE.md` +- **verl design log**: `recovery/reinforcement learning/VERL_RAY_MX.md` +- **General MX for RL design**: `docs/MX_RL_OVERVIEW.md` +- **TensorHub comparison** (pipeline replication roadmap): `recovery/reinforcement learning/TensorHub_Analysis.md` + +### Upstream repos + +| Repo | Branch | Key files | +|------|--------|-----------| +| `github.com/KavinKrishnan/verl` | `kavink/mx-checkpoint-engine` | `verl/checkpoint_engine/mx_checkpoint_engine.py`, `verl/utils/flash_attn_compat.py`, `docker/Dockerfile.mx-arm64`, `k8s/verl-mx-poc/*` | +| `github.com/ai-dynamo/modelexpress` | `kavink/RL` | `training_publisher.py`, `refit_receiver.py`, `nixl_transfer.py`, `client.py` | From 8991660112c66b49bc30fb4779fd1a2a15771cbb Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Wed, 22 Apr 2026 12:34:34 -0700 Subject: [PATCH 13/25] docs(verl): remove related-documents section with internal-only paths The section referenced recovery/ paths outside the ModelExpress repo that aren't accessible to external readers. Made-with: Cursor Signed-off-by: Kavin Krishnan --- docs/RL/VERL_MX_OVERVIEW.md | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/docs/RL/VERL_MX_OVERVIEW.md b/docs/RL/VERL_MX_OVERVIEW.md index e7df334d..d5ced9b2 100644 --- a/docs/RL/VERL_MX_OVERVIEW.md +++ b/docs/RL/VERL_MX_OVERVIEW.md @@ -473,19 +473,3 @@ This proved publisher/receiver correctness before moving to cross-node. 3. **The ARM64 path is painful but survivable.** All the image work is in `docker/Dockerfile.mx-arm64` and the compat shim, and is shared with (and borrowed from) the PRIME-RL POC. 4. **Standalone rollout is the production shape.** Hybrid/colocated mode is useful for debugging but cannot drive any non-naive checkpoint engine β€” true for NIXL, NCCL, and MX alike. ---- - -## 6. Related Documents - -- **PRIME-RL POC**: `recovery/reinforcement learning/PRIMERL_MX_NIXL_Overview.md` -- **verl POC state log**: `recovery/reinforcement learning/VERL_POC_STATE.md` -- **verl design log**: `recovery/reinforcement learning/VERL_RAY_MX.md` -- **General MX for RL design**: `docs/MX_RL_OVERVIEW.md` -- **TensorHub comparison** (pipeline replication roadmap): `recovery/reinforcement learning/TensorHub_Analysis.md` - -### Upstream repos - -| Repo | Branch | Key files | -|------|--------|-----------| -| `github.com/KavinKrishnan/verl` | `kavink/mx-checkpoint-engine` | `verl/checkpoint_engine/mx_checkpoint_engine.py`, `verl/utils/flash_attn_compat.py`, `docker/Dockerfile.mx-arm64`, `k8s/verl-mx-poc/*` | -| `github.com/ai-dynamo/modelexpress` | `kavink/RL` | `training_publisher.py`, `refit_receiver.py`, `nixl_transfer.py`, `client.py` | From 90e45eabe414ecad8af91a8fce0baf4a7653ae77 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Wed, 22 Apr 2026 15:19:46 -0700 Subject: [PATCH 14/25] docs(RL): cross-reference draft overlay PR #2343 in the design doc Made-with: Cursor Signed-off-by: Kavin Krishnan --- docs/RL/PRIMERL_MX_OVERVIEW.md | 707 +++++++++++++++++++++++++++++++++ 1 file changed, 707 insertions(+) create mode 100644 docs/RL/PRIMERL_MX_OVERVIEW.md diff --git a/docs/RL/PRIMERL_MX_OVERVIEW.md b/docs/RL/PRIMERL_MX_OVERVIEW.md new file mode 100644 index 00000000..d381a5c8 --- /dev/null +++ b/docs/RL/PRIMERL_MX_OVERVIEW.md @@ -0,0 +1,707 @@ +# ModelExpress Γ— PRIME-RL β€” Design Overview + +**Last Updated**: April 2026 +**Status**: Design complete; prototype overlay on top of [PrimeIntellect-ai/prime-rl#2326](https://github.com/PrimeIntellect-ai/prime-rl/pull/2326) targeting GB200 (ARM64, GKE). Metrics sections below are populated as the benchmark run produces data. + +This document covers how ModelExpress (MX) plugs into [PRIME-RL](https://github.com/PrimeIntellect-ai/prime-rl)'s NIXL weight-transfer path as a **metadata and elasticity layer on top of** the existing `NIXLWeightBroadcast` / `TransportPlan` introduced by PR #2326. We do not reimplement their transport. We replace the SPG (StatelessProcessGroup) rendezvous with an MX-Server-mediated discovery plane, add pipeline replication, add a mutability contract, and enable a scratch-buffer diagnostic mode β€” all opt-in behind a single config flag. + +--- + +## 1. Design Overview + +### What MX adds to PRIME-RL's NIXL backend + +PR #2326 gives PRIME-RL a bit-exact RDMA weight transport built on NIXL/UCX over RoCE, with slots (`ShardedSlot` / `GatheredSlot` / `ExpertSlot`), model-agnostic `ConversionSpec` / `QuantizationSpec`, FP8 trainer-side quantization, HSDP primary-replica push, per-rank NIC pinning, and an `expandable_segments`-safe `CUDAPluggableAllocator` slot pool. The transport works. What it doesn't have is a dynamic discovery plane. + +| Layer | Role in PR #2326 | Role with MX overlay | +|-------|------------------|----------------------| +| Data plane | NIXL RDMA (UCX / `rc_mlx5` / RoCE) | **Unchanged** β€” identical bytes on the wire | +| Slot / bucket system | `ShardedSlot`, `GatheredSlot`, `ExpertSlot`, `TransportPlan` | **Unchanged** β€” imported as-is | +| Publishing topology | **Per-rank sharding-aware** β€” each trainer rank publishes its own FSDP / TP / EP shard; no rank-0 allgather | **Unchanged** β€” this is a core property of PI's foundation, inherited by the overlay (see Β§3.9) | +| Quantization / conversion | `ConversionSpec`, `QuantizationSpec` | **Unchanged** β€” same trainer-side FP8 path | +| Rendezvous / discovery | SPG β€” static, rank-paired, global-world-size fixed at init | **Replaced by MX Server** (gRPC + Redis) when `rendezvous: mx_server` is set | +| Topology | Star (trainer rank k β†’ inference rank k, 1:1, no fan-in to rank 0) β€” trainer NIC is the single source for all fan-out | **Dynamic DAG** β€” trainer seeds the first rollout; each finished rollout becomes an additional source; MX Server load-balances new pollers across the growing source set. Same TensorHub pattern. Trainer NIC stops being a bottleneck once any rollout has received (Β§3.2) | +| Mutability contract | None | **Explicit `publish` / `unpublish`** β€” trainer publisher marks slots immutable during rollout pulls, mutable before `optimizer.step()` | +| Elastic topology | No β€” SPG locks `dp_shardΓ—cp + inference_ws` at boot | **Yes** β€” rollouts can join / leave mid-run via `poll_for_source` | +| Retention | None β€” no version history | **Keep-latest-N** β€” MX Server reaper preserves designated versions, CPU-offloads the last GPU copy if necessary | +| Cross-framework | prime-rl only | **Same MX client** also powers verl `MxCheckpointEngine`, future NeMo-RL | +| Expert-aware source tracking (MoE) | Implicit (ExpertSlot in client; no server-level index) | **Explicit server-side `(model, version, expert_id) β†’ worker` index** + `poll_for_expert_source` RPC. Low-hanging win β€” primitives already exist in MX Server, overlay wires them up (Β§3.7) | +| Peer recovery on pod restart | Not available β€” recovering rank must re-pull from trainer | **Multi-source discovery** β€” `poll_for_sources` returns ranked live peers holding the current version of rank k's shard; recovering rank pulls from nearest/least-loaded peer. Uses the same source index pipeline replication writes to. No event log / no version replay (Β§3.10) | +| Scratch-buffer diagnostic | Not supported (direct refit only) | **Opt-in via `transfer_mode: scratch`** β€” uses PI's same transport but stages into isolated GPU tensors + `model.load_weights()` for KL-drift triangulation | + +### Component diagram (vertical, document-friendly) + +```mermaid +graph TB + subgraph driver["Driver Β· CPU Β· orchestrator process"] + orch["RL Orchestrator
(existing)"] + httpapi["/pause /resume /update_weights
(vLLM WeightTransferEngine endpoints)"] + orch --> httpapi + end + + subgraph mx_meta["Metadata Plane Β· CPU"] + mx["MX Server
(gRPC)"] + redis[("Redis")] + mx --> redis + end + + subgraph trainer["Trainer node Β· FSDP2 + optimizer"] + direction TB + tw["Trainer ranks Γ— N
(dp_shard Γ— cp)"] + tp["NIXLWeightBroadcast
+ TransportPlan
(PI's code, unchanged)"] + pub["MxTrainingPublisher
(MX overlay)"] + tnixl(["NIXL Agent Γ— N"]) + tw --> tp + tp -->|slot registry| pub + tp --> tnixl + end + + subgraph rollout["Rollout nodes Β· vLLM TP"] + direction TB + cew["NIXLWeightUpdateWorker Γ— M
(PI's code, unchanged)"] + rcv["MxRefitReceiver
(MX overlay)"] + rnixl(["NIXL Agent Γ— M"]) + vllm["vLLM engine Γ— M
(live params)"] + cew --> rcv + cew --> rnixl + cew -. "in-place RDMA WRITE
or scratch-buffer stage" .-> vllm + end + + pub -- "gRPC publish_agent
slots + agent_meta + version" --> mx + rcv -- "gRPC poll_for_source
(model_name, worker_rank)" --> mx + mx -- "trainer agent_meta
slot layout + NIXL blob" --> rcv + tnixl <== "NIXL RDMA WRITE
RoCE Β· rc_mlx5
(PI transport, unchanged)" ==> rnixl + rcv -. "publish_rollout_source
(pipeline replication)" .-> mx + + style driver fill:#1a1a2e,stroke:#533483,color:#e0e0e0 + style mx_meta fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0 + style trainer fill:#0f3460,stroke:#533483,color:#e0e0e0 + style rollout fill:#0f3460,stroke:#533483,color:#e0e0e0 + style orch fill:#533483,stroke:#e94560,color:#fff + style httpapi fill:#533483,stroke:#e94560,color:#fff + style tw fill:#533483,stroke:#e94560,color:#fff + style tp fill:#533483,stroke:#e94560,color:#fff + style cew fill:#533483,stroke:#e94560,color:#fff + style vllm fill:#533483,stroke:#e94560,color:#fff + style pub fill:#1b5e20,stroke:#4caf50,color:#fff + style rcv fill:#1b5e20,stroke:#4caf50,color:#fff + style mx fill:#1b5e20,stroke:#4caf50,color:#fff + style tnixl fill:#2e7d32,stroke:#66bb6a,color:#fff + style rnixl fill:#2e7d32,stroke:#66bb6a,color:#fff + style redis fill:#162447,stroke:#533483,color:#e0e0e0 +``` + +**Legend**: Green boxes = MX/NIXL additions (metadata plane + overlay client classes). Purple = existing PRIME-RL / vLLM / PI-PR-#2326 components. The trainer-to-rollout NIXL arrow is the exact same RDMA WRITE path PI introduced; MX does not touch the data plane. + +### Key ideas + +- **MX Server stores metadata only.** Slot layouts, tensor descriptors, NIXL agent blobs, version numbers. It never touches weight bytes. +- **The data path is PI's, unchanged.** NIXLWeightBroadcast + TransportPlan + Slot classes are imported and used as-is. Our value-add is what happens *before* (discovery) and *alongside* (lifecycle, pipeline replication, diagnostics). +- **Opt-in via one config field.** `weight_broadcast.rendezvous: "spg" | "mx_server"` (default `"spg"`). Flip the flag, no code paths diverge. +- **Pipeline replication is a client-side change.** After a rollout receives weights, it optionally re-registers itself as a source. The next rollout to poll discovers *either* the trainer or a replicated rollout, closer / less-loaded wins. Amplifies trainer NIC bandwidth in fan-out-heavy topologies. +- **Scratch-buffer diagnostic mode** reuses PI's transport but lands writes in isolated GPU tensors, then applies via `model.load_weights()`. Used for triangulating correctness issues like the KL drift in #2326. + +--- + +## 2. Timing Diagram β€” One `update_weights` Step + +Shows the MX-mediated path (`rendezvous: mx_server`). The SPG path is unchanged from PI #2326. + +```mermaid +sequenceDiagram + participant O as Orchestrator + participant T as Trainer rank k
(TransportPlan) + participant PUB as MxTrainingPublisher + participant MX as MX Server + participant RCV as MxRefitReceiver + participant R as NIXLWeightUpdateWorker
(rollout rank k) + participant V as vLLM engine + + Note over T: optimizer.step() complete + + O->>R: POST /pause + R-->>O: 200 OK (quiesced) + + par publish (trainer) + discover (rollout) + T->>PUB: prepare_slots(slots, agent_meta, version=N) + PUB->>MX: gRPC publish(model, agents[], slot_layout[], version=N) + MX-->>PUB: OK (mark version N publishable) + R->>RCV: init(model_name, worker_rank=k) + RCV->>MX: gRPC poll_for_source(model, worker_rank=k, min_version=N) + MX-->>RCV: agent_meta, slot_layout, source_id + end + + Note over T,R: (no SPG init needed; rendezvous complete via MX) + + T->>T: dist.barrier() (pre-write quiescence) + + loop per slot bucket (PI's chunked drain) + T->>T: pack slot β†’ GPU bucket, NIXL WRITE to rollout + T-->>R: NIXL RDMA WRITE (RoCE, rc_mlx5) + end + + T->>PUB: publish.finalize(version=N, done=true) + PUB->>MX: gRPC mark_version_ready(version=N) + R->>RCV: finalize() + RCV->>V: in-place refit complete
(or scratch apply via load_weights) + + opt pipeline_replication=true + RCV->>MX: gRPC publish_rollout_source(model, version=N, agent_meta) + Note over MX: subsequent rollouts poll
and may discover this rollout
as source + end + + opt next iteration β€” trainer about to mutate slots + T->>PUB: unpublish(version=N) + PUB->>MX: gRPC unpublish(version=N) + MX->>MX: wait for in-flight pulls to drain + MX-->>PUB: OK (safe to mutate) + end + + O->>R: POST /resume + R-->>O: 200 OK +``` + +### Observed per-step timing + +_These numbers are populated from the GB200 benchmark run described in Β§4. Until that run completes, cells are marked **TBD** and prior PI-reported numbers are noted for reference._ + +| Phase | PI SPG (12-node prod, reported) | MX rendezvous (GB200 2-node, measured) | MX + pipeline replication (GB200, projected) | +|-------|---------------------------------|----------------------------------------|----------------------------------------------| +| Rendezvous (SPG init vs MX poll) | ~0.8s (post-iter15 pre-write barrier) | **TBD** (target: ≀100 ms first poll, ≀20 ms steady-state) | **TBD** | +| `send_weights` / `receive_weights` (RDMA) | ~7.5 GB/s wire / 20 GB/s net | **TBD** (target: parity with PI β€” same transport) | **TBD** (target: linear scale with replica count) | +| `finalize` | ~0.1s | **TBD** | **TBD** | +| **Total `update_weights`** | β€” | **TBD** | **TBD** | + +Parity with PI on the data path is the acceptance criterion for the MX overlay β€” any regression means we've accidentally touched the hot path, which is not the design. + +--- + +## 3. ModelExpress Value Layer + +This section documents what the overlay changes relative to PI's PR #2326. + +### 3.1 SPG β†’ MX Server rendezvous + +**What SPG provides in #2326**: A fixed-world-size group over TCP, used to exchange NIXL agent metadata at init. Every participant must be present at the same time; adding/removing a rollout requires a full process restart. + +**What MX Server provides**: +- Each trainer rank calls `MxTrainingPublisher.publish(agent_meta, slot_layout, version)` once per step (gRPC). +- Each rollout calls `MxRefitReceiver.poll_for_source(model_name, worker_rank, min_version)` β€” returns the matching trainer rank's agent metadata and slot layout. +- Poll is idempotent and cache-friendly; rollouts can join mid-run, leave, or be restarted without affecting other participants. + +**Config surface**: + +```yaml +weight_broadcast: + type: nixl # use PI's transport + rendezvous: mx_server # instead of spg + mx_server_url: modelexpress-server.kavin.svc.cluster.local:8001 + model_name: "zai-org/GLM-4.5-Air-FP8" # example; see Β§4 for final selection +``` + +When `rendezvous: spg` (the default), behavior is 100% identical to PI's PR #2326. + +### 3.2 Pipeline replication β€” dynamic DAG of rollouts-as-sources + +Rollouts form a **dynamic DAG**, not a static star. Every `publish_rollout_source(version=N)` call adds a new parent edge available to unfinished rollouts. Every `poll_for_source(version=N)` gets load-balanced across the currently-available parent set (trainer + any rollouts that have already finalized version N). The DAG is built organically as receives complete; there is no precomputed topology. + +This is the same architectural pattern as TensorHub's Reference-Oriented Storage (ByteDance, April 2026; see `recovery/reinforcement learning/TensorHub_Analysis.md` for the design comparison). + +```yaml +weight_broadcast: + rendezvous: mx_server + pipeline_replication: true # default false +``` + +**DAG buildup over time** (12 rollouts, single trainer source for a given rank k): + +``` +t=0 Trainer publishes version N. + Sources for version N: {Trainer}. + MX Server DAG: Trainer ──→ (R0..R11 all polling) + +t=t0 Trainer β†’ R0 RDMA completes first. + R0 calls publish_rollout_source(version=N). + Sources: {Trainer, R0}. + MX Server DAG: Trainer ──→ (R1..R11 polling) + β”‚ + └─ R0 ──→ (next pollers can choose R0 or Trainer) + +t=t1 R1 and R2 pull in parallel from {Trainer, R0} (server load-balances). + Both finalize; publish_rollout_source(). + Sources: {Trainer, R0, R1, R2}. + Effective outbound: 4 NICs serving R3..R11. + +t=t2 R3..R6 finalize from {Trainer, R0, R1, R2}. + Sources: {Trainer, R0..R6}. + Effective outbound: 8 NICs serving R7..R11. + +t=t3 R7..R11 finalize. + All 12 rollouts hold version N. +``` + +**Bandwidth math**: A naive star with T trainer NICs serving R rollouts caps aggregate throughput at T Γ— per-NIC-BW, regardless of R. The DAG caps aggregate throughput at R Γ— per-NIC-BW (every GPU's outbound contributes once it has received). For R=12 and T=8 on the PI prod shape, this is a 1.5Γ— headroom; for R=64 on a future scale-out, it's 8Γ— headroom. + +**Load-balancing preference** (pipeline replication mode, distinct from Β§3.10 peer recovery): + +- Spread load evenly across currently-available sources (round-robin within a locality tier). +- Prefer same-rack sources over cross-rack to minimize inter-switch hops. +- Avoid overloading the trainer β€” once any rollout is available, weight the trainer lower in selection so its NIC stays free to seed new pushes. + +Contrast with Β§3.10 peer-recovery preference, which prefers same-node > same-rack > any > trainer-last (optimizing for recovery *latency* vs pipeline *throughput*). + +**Server-side state used** (shared with peer recovery in Β§3.10 β€” same index, two entry points): + +``` +sources_index : Map<(model, version, worker_rank), Set> +source_health : Map +source_load : Map // for load-balancing +``` + +**Current limit (TensorHub has, we don't yet)**: We call `publish_rollout_source()` after `finalize()` β€” i.e., once *all* slots have been received. TensorHub goes further and lets a partially-replicated rollout serve its *completed* slots while still receiving others. This deepens the pipelining further (rollout A's slot 0 can feed rollout B's slot 0 even while rollout A is still pulling slot 1 from the trainer). We've chosen post-finalize-only for the initial overlay to keep the correctness surface small; partial-replica serving is in the future-work list with a clear enablement path: + +1. Publisher exposes `publish_partial_source(version, completed_slots[])` that accepts a slot-id bitmap. +2. Server-side index keys on `(model, version, worker_rank, slot_id)` instead of `(model, version, worker_rank)`. +3. Receiver filters candidate sources per slot, composes multi-source pulls. + +Low-risk to add post-merge; no impact on the initial PR's success criteria. + +### 3.3 Mutability contract (publish / unpublish) + +PI's protocol currently relies on a `dist.barrier()` after `NIXL_READY` to ensure trainer ranks don't start writing while inference is still reading (iter15 fix). This works for synchronous push, but at async level β‰₯ 1 (pre-fetch next rollouts while current training step runs) the trainer will re-use slot buffers for the next `optimizer.step()` while rollouts may still be pulling the previous version. + +The MX overlay adds: +- `unpublish(version=N)` gRPC β€” trainer calls this just before buffer mutation. +- MX Server blocks `unpublish` until all in-flight pulls for version N have completed (tracked via heartbeat ACKs from rollouts). +- Publisher then signals the trainer that slot buffers are safe to mutate. + +### 3.4 Retention protocol + +MX Server enforces keep-latest-N versions per model (default N=2). If the last GPU-resident copy of a retained version is about to be unpublished, the server offloads that version to CPU memory as a fallback. Rollouts pulling a version that exists only on CPU pay a higher latency but don't fail. Prevents version loss under elastic churn; matches TensorHub retention semantics. + +### 3.5 Scratch-buffer diagnostic mode + +PI's direct-refit writes RDMA-received bytes directly into live vLLM parameter memory. The KL drift investigation in #2326 (27+ iterations, unresolved) narrowed the bug to "NIXL write mechanism itself" β€” write-ordering / visibility / tensor-identity hazards at the intersection of RDMA and live CUDA tensors. + +The MX overlay offers an alternate target: + +```yaml +weight_broadcast: + rendezvous: mx_server + transfer_mode: scratch # default: direct +``` + +In `scratch` mode: +- RDMA writes land in isolated scratch GPU tensors allocated by the receiver, not in live vLLM params. +- After the transfer completes, `model.load_weights()` applies them β€” the same code path NCCL uses. +- Bit-exact byte check passes in both modes (transport is identical); any KL divergence between `direct` and `scratch` isolates the bug to the direct-refit target layout (kernel format, stride, identity), *not* the NIXL mechanism. + +Memory cost is scratch-sized (~3.5 GB for 1.5B, ~15 GB for 7B, ~30 GB for 32B-class MoE). Intended for diagnostic runs, not production. + +### 3.6 Cross-framework reuse + +The same `MxTrainingPublisher` / `MxRefitReceiver` classes underpin `MxCheckpointEngine` in [verl](https://github.com/volcengine/verl). Seek `docs/RL/VERL_MX_OVERVIEW.md` for that integration's topology. Future NeMo-RL integration uses the same client. No framework-specific server code exists in MX Server. + +### 3.7 Expert-aware source tracking (MoE) + +For MoE models at expert parallelism EP > 1, each inference worker holds only a *subset* of experts. Broadcasting every expert to every worker (NCCL, filesystem) wastes `(EP - 1)/EP` of the bandwidth. MX's per-worker publishing model makes expert-selective transfer natural β€” every trainer EP rank publishes only its local experts; every inference EP rank pulls only its assigned experts from the matching source. + +**Server-side state** (primitives already defined, logic added in this overlay): + +- `WorkerMetadata.worker_rank` identifies the publishing EP rank. +- `SourceIdentity.expert_parallel_size` declares the source's EP degree. +- `TensorDescriptor.name` carries the expert index (`...experts.{N}.gate_proj.weight`). +- New server index `(model, version, expert_id) β†’ worker`, built incrementally as publishers announce slots. + +**New RPCs** in MX Server (added alongside the overlay): + +| RPC | Purpose | +|-----|---------| +| `publish_expert_ownership(source_id, expert_ids[])` | Trainer EP rank announces its assigned experts for the current version | +| `poll_for_expert_source(model, version, expert_id)` | Inference EP rank asks "who holds expert N?"; server returns matching source's agent meta | +| `list_experts(model, version)` | Diagnostic β€” returns the `expert_id β†’ worker` map | + +**Client-side** β€” once we adopt PI's `ExpertSlot` in Phase 1, the data is already produced by the publisher; the overlay just flushes it to MX Server instead of keeping it SPG-local. + +**Why this matters**: + +- **Bandwidth**: With EP=8 and 64 experts (matches both GLM-5 and Qwen2-57B-A14B), each inference worker needs 1/8th of expert bytes. Broadcast-based transports can't exploit this; MX can. +- **Future load-balancing**: The server-side index enables hot-expert migration (move frequently-used experts closer to their consumers), elastic expert redistribution, and dynamic expert pruning. +- **Minimal effort**: ~2 days server, ~1 day client, ~1 day tests β€” the primitives already exist. + +### 3.8 Scratch-buffer evolution path + +The scratch-buffer path in Β§3.5 is the correct *default* β€” same `model.load_weights()` code path NCCL uses, low correctness risk, clear A/B isolation. But it is not long-term feasible: at 32B the scratch approaches 35% of GB200 HBM; at 70B it no longer fits alongside a realistic KV cache. + +We document a 5-tier evolution so the design has a clear migration target as each correctness concern gets retired: + +| Tier | Approach | Scratch memory | Conversion cost | Correctness risk | Trigger to advance | +|------|----------|---------------|-----------------|------------------|-------------------| +| 0 | **Current** β€” scratch + `load_weights()` on receiver | Full model | High (every rollout) | Low | β€” | +| 1 | **ConversionSpec on trainer** + scratch on receiver (`load_weights` becomes plain copy) | Full model (kernel-format) | Paid once on trainer | Low | Adopt PI's `ConversionSpec`/`QuantizationSpec` (Phase 1) | +| 2 | **Streaming per-tensor** β€” receive one tensor, apply, free, next | Largest single tensor (~0.5–2 GB) | Low | Medium β€” breaks RDMA bucket batching; per-sub-bucket registration overhead | Need to run >32B before direct-refit is proven | +| 3 | **Tiled / rotating chunked scratch** β€” fixed cap (e.g., 2 GB) reused across tensors | Fixed cap (configurable) | Low | Medium β€” trainer/receiver must coordinate chunk cadence | Same as Tier 2 but prefers bounded memory to minimum-latency | +| 4 | **Direct refit** β€” RDMA writes land in live `param.data`; zero scratch (PI's approach, same code path as their PR #2326 direct mode) | Zero | Zero | **High** β€” live-param corruption, RDMA ordering, tensor identity hazards (see PI's unresolved KL drift in #2326) | KL drift root cause resolved; tensor layout stability contract proven | +| Tier-4 alt | **CPU offload fallback** β€” receive into pinned CPU RAM, DMA to GPU per layer | Zero GPU, full CPU RAM | PCIe copy per push | Low but slower | Only when GPU memory is the binding constraint | + +**Progression logic**: + +- Tiers 0 β†’ 1 is pure adoption of PI's trainer-side conversion; no new correctness risks. +- Tiers 1 β†’ {2, 3} is a memory optimization when scratch no longer fits. Both are valid; choose per deployment. +- Tier 4 is the terminal state but *only* after PI's KL drift investigation converges on a root cause. The `transfer_mode: scratch` diagnostic (Β§3.5) is the tool that gates this transition β€” if the drift persists in Tier 4 but not Tier 2/3, we've falsified direct-refit for that model family. + +**What the overlay PR implements now**: + +- Tier 0 β€” `transfer_mode: scratch` default on our overlay path. +- Tier 4 β€” available via `transfer_mode: direct` (imports PI's direct-refit code path unchanged). +- Tiers 1/2/3 β€” documented here as the migration target, not yet implemented. See `PRIMERL_POC_Next_Steps.md` Steps 9 + 10 for the tracked work items. + +### 3.9 Per-rank sharding-aware publishing (no rank-0 allgather) + +PI's `TransportPlan` + `ShardedSlot` / `GatheredSlot` / `ExpertSlot` design means every trainer rank publishes its *own local shard* directly to its matching inference rank. No rank ever holds the full unsharded model. This is inherited unchanged by the overlay β€” we don't reintroduce an allgather anywhere. + +**Contrast with the naive path** (what our pre-pivot MX POC on `kavink/mx-weight-broadcast` does, and what filesystem / NCCL-broadcast backends effectively do): + +``` +Before (naive / pre-pivot MX POC): + Rank 0 ──┐ + Rank 1 ──┼── allgather ──► Rank 0 holds full state_dict ──► 1Γ— NIXL WRITE ──► Inference + Rank 2 ─── (3.55 GB on 1.5B, 15 GB on 7B, + Rank 3 β”€β”€β”˜ 65 GB on 32B β€” does not fit!) + + Cost: 4x memory spike on rank 0, single NIC used, allgather + serializes all ranks, does not scale past ~30B. + +After (overlay on top of PI): + Rank 0 ── ShardedSlot 0 ── NIXL agent 0 ── RDMA WRITE ──► Inference rank 0 + Rank 1 ── ShardedSlot 1 ── NIXL agent 1 ── RDMA WRITE ──► Inference rank 1 + Rank 2 ── ShardedSlot 2 ── NIXL agent 2 ── RDMA WRITE ──► Inference rank 2 + Rank 3 ── ShardedSlot 3 ── NIXL agent 3 ── RDMA WRITE ──► Inference rank 3 + + Cost: zero memory spike, 4 NICs in parallel, each rank's transfer + is independent, scales linearly with rank count. +``` + +**Slot-type guarantees**: + +| Slot | Source shape | Publish pattern | Memory on any single GPU | +|------|-------------|-----------------|--------------------------| +| `ShardedSlot` | FSDP2-sharded (DTensor) | Each rank publishes `param.to_local()` | Only its local shard β€” same as training | +| `ExpertSlot` | EP-sharded (experts assigned per rank) | Each EP rank publishes its local experts | Only local experts | +| `GatheredSlot` | Small tensors (< 2 MiB threshold) | Rank 0 gathers and publishes as a bundle | Sum of small tensors (few MB) β€” handle-count optimization, not a correctness requirement | + +**HSDP** (hybrid-sharded data parallel) further restricts: only `dp_replicate == 0` runs the protocol at all. Non-primary replicas do not allocate NIXL slot buffers, do not register with MX Server, do not send on the wire. They `dist.barrier()` at the end to stay in lockstep with the primary's push. + +**Net effect on GB200 2-node shape** (4 trainer ranks Γ— 4 inference ranks): + +- Memory: 0 GB spike on any single rank regardless of model size. +- NIC utilization: 4 outbound streams in parallel on trainer, 4 inbound on rollout. Total bandwidth = sum of per-rank NICs, not capped at one NIC. +- Correctness: per-rank byte-exact byte-exact transfer (PI iter16 `nixl_diff.py` confirmed across all slot types). + +**Retiring Step 8**: `PRIMERL_POC_Next_Steps.md` Step 8 ("Eliminate rank-0 allgather β€” per-rank shard publishing") was one of our original P0 roadmap items. It is now absorbed by the pivot: adopting PI's `Slot` + `TransportPlan` gives us this behavior at Phase 1, with no additional MX-side code to write for the *publishing* topology itself. What remains on our side is (a) the MX rendezvous that routes rank-k β†’ rank-k discovery through the server instead of SPG, and (b) the server-side expert-aware index in Β§3.7. + +### 3.10 Peer recovery and source redundancy + +A rollout pod crashes and restarts. Without recovery support it must re-pull its shard from the trainer, consuming trainer NIC bandwidth that would otherwise serve new pushes. MX Server's source index β€” the same index populated by pipeline replication (Β§3.2) and per-rank publishing (Β§3.9) β€” lets a recovering rank discover live peers that already hold the current version and pull from the closest / least-loaded one. + +**Server-side state** (a small extension of what pipeline replication already requires): + +``` +sources_index : Map<(model, version, worker_rank), Set> +source_health : Map // TTL-driven liveness, e.g. 10 s +``` + +Every `publish()` or `publish_rollout_source()` call inserts into `sources_index`. Every gRPC RPC from a source refreshes `source_health`. A reaper removes sources whose heartbeats have expired. + +**Recovery API** β€” `poll_for_source` returns a ranked list, not a single source: + +```python +# Before +source: Source = mx_client.poll_for_source(model, version, worker_rank=k) + +# After (additive; old single-source call still works, wraps list[0]) +sources: list[Source] = mx_client.poll_for_sources( + model, version, worker_rank=k, + prefer=["same_node", "same_rack", "rollout_replica", "trainer"], + max_results=4, +) +receiver.receive_weights_with_fallback(sources) # tries [0], on failure falls through [1..] +``` + +**Preference ordering** (default; configurable): + +1. Same-node source (PCIe copy between processes on the same host, no network involved). +2. Same-rack source (cheaper RDMA hop). +3. Any rollout replica (preserves trainer NIC bandwidth). +4. Trainer rank k (last resort β€” trainer NIC is the scaling bottleneck for new pushes). + +**Why this is cheap in our design**: + +- The `sources_index` is *already* written to by every publish β€” no new write path. +- The `source_health` heartbeat is *already* needed to implement retention (Β§3.4) reaping. +- The only net-new is (a) returning a list, (b) client-side fallback loop on RDMA connect failure, (c) preference-ordered ranking in `poll_for_sources`. + +**Sharding-aware natural fit**: because publishing is per-rank (Β§3.9), a recovering rank k pulls *only its own shard* from peers that have rank k's shard β€” it does not need to know about ranks 0/1/... and does not re-pull the full state dict. Recovery cost scales with the shard size, not the full model. + +**What this explicitly does NOT do**: + +- **No event log / version replay.** Weights are a point-in-time snapshot, not a sequence of operations. A recovering rank always jumps directly to the current (or requested) version in one transfer. If it was down during versions 95-100, it does not apply updates 95 β†’ 96 β†’ 97 β†’ ...; it copies the state at version 100 from whichever peer has it. +- **No weight deltas in transit.** For standard RL (PPO/GRPO), every parameter changes on every `optimizer.step()`, so `weights[N] - weights[N-1]` is dense and compresses poorly. The bandwidth cost equals the full transfer; the memory cost doubles (sender and receiver both need `weights[N-1]`). Not worth the complexity for dense-update RL. See "Future" below for the structured-sparse case. + +**Retention + recovery interaction**: + +- If the retention protocol (Β§3.4) keeps latest-N versions, recovery can target any retained version. Typical use: default N=2 lets a newly-booted rollout catch up to "most recent ready" even if the trainer has already advanced one step. +- If the only live source for a retained version is about to heartbeat-out, the retention path triggers CPU offload before eviction, so the version survives until a fresh source publishes. + +**Future: weight deltas for structured-sparse RL** + +For LoRA-RL, adapter-only RL, or reward-frozen-policy-adapts where only a small submodule updates each step, an actual delta-transfer protocol becomes interesting: + +- Trainer publishes `delta_slot = weights[N].adapter - weights[N-1].adapter` (tiny compared to full weights). +- Receivers apply `param.data += delta` in place. +- MX Server tracks delta lineage per version. +- Peer recovery in this mode asks "give me all deltas from version 95 to 100" β€” which becomes a log-replay style recovery. + +This is a natural extension of the source index but requires: +- Per-slot base-version tracking. +- Delta application protocol on receiver. +- Fallback to full transfer if a delta is missing (e.g., all sources pruned by retention). + +Not implemented in the current overlay β€” surfaced here as a documented future direction in the Β§3.8 evolution path. + +--- + +## 4. Prototype on GB200 β€” Results + +### 4.1 Cluster + +| Resource | Value | +|----------|-------| +| Platform | GKE DGXCloud, GB200 ARM64 | +| Nodes | 2 (trainer + rollout), `hostNetwork=true` | +| GPUs | 8 Γ— GB200 (4 per node) | +| Node pools | `customer-gpu-w0e` (trainer), `customer-gpu-o7v` (rollout) | +| Fabric | RoCE v2, IMEX via DRA `kavin-compute-domain-channel` | +| UCX | v1.20.0 built from source, `self,sm,rc,cuda_copy,gdr_copy,tcp` | +| NIXL | v1.1.0 (main) | +| PyTorch | 2.6 + cu128 | +| vLLM | 0.18.1 (0.19.0 has ARM64 `resource_tracker` multiproc bug) | +| MX Server | `modelexpress-server.kavin.svc.cluster.local:8001` | +| Image | `nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:latest` (from overlay branch) | +| Base PR | [PrimeIntellect-ai/prime-rl#2326](https://github.com/PrimeIntellect-ai/prime-rl/pull/2326) @ commit TBD | + +### 4.2 Model (tiered) + +PI's PR #2326 targets an internal GLM-5 MoE FP8 model that isn't publicly available. We exercise the same PR #2326 code paths (ShardedSlot, GatheredSlot, ExpertSlot, ConversionSpec, QuantizationSpec FP8 2D/3D, HSDP primary-replica) using publicly available models of similar architecture, in two tiers: + +| Tier | Model | Params | PR #2326 paths exercised | Fit on 2-node GB200 | Role | +|------|-------|--------|--------------------------|---------------------|------| +| **T1** | `Qwen/Qwen2.5-7B` BF16 | 7.6B dense | ShardedSlot, GatheredSlot, TransportPlan, MX rendezvous, pipeline replication, elastic join | βœ… Comfortable β€” known-good on our existing POC | **Primary first pass**: validates MX overlay end-to-end | +| **T2** | `Qwen/Qwen2-57B-A14B-Instruct` FP8 | 57B total / 14B active, 64 experts | T1 + ExpertSlot, QuantizationSpec FP8 2D + 3D, non-layer specs | Tight β€” FSDP=4 trainer, TP=4 inference, ~25 GB/GPU combined | **Stretch / same PR if time permits**: matches PI GLM-5 expert topology closely (64 experts) | +| T2 fallback | `mistralai/Mixtral-8x7B-Instruct-v0.1` FP8 | 47B total / 13B active, 8 experts | Same code paths as T2, fewer experts | Comfortable | If Qwen2-57B has integration issues | + +**Why two tiers**: T1 validates the MX overlay (rendezvous, elastic join, pipeline replication) on hardware we know works. If T2 hits issues in MoE routing or FP8 conversion, T1 results alone are sufficient to demonstrate the MX value proposition. If T1 passes, T2 is primarily a matter of authoring the model-specific `ConversionSpec` table (and benefiting from PI's already-proven GLM spec patterns). + +**Why Qwen2-57B-A14B over `zai-org/GLM-4.5-Air`**: Qwen2-57B has 64 experts (direct match to PI's GLM-5 expert count), well-documented FP8 variants, leaves headroom on 2-node GB200. GLM-4.5-Air at 106B is borderline feasible but higher risk for the demo. + +Selection confirmed at first-boot feasibility test in W1 (see `PRIMERL_MX_OVERLAY_PR_PLAN.md` Β§6.2). + +### 4.3 Deployment shape + +``` +Node 1 (customer-gpu-w0e, IP 10.0.0.83) +β”œβ”€ StatefulSet: prime-rl-mx-trainer-0 +β”‚ β”œβ”€ 4Γ— FSDP2 trainer ranks +β”‚ β”œβ”€ NIXLWeightBroadcast + TransportPlan (PI) +β”‚ └─ MxTrainingPublisher (overlay) +└─ 4Γ— NIXL agent (rc_mlx5), per-rank NIC pin + +Node 2 (customer-gpu-o7v, IP 10.0.15.225) +β”œβ”€ StatefulSet: prime-rl-mx-rollout-{0,1,2,3} +β”‚ β”œβ”€ NIXLWeightUpdateWorker (PI) +β”‚ β”œβ”€ MxRefitReceiver (overlay) +β”‚ └─ vLLM engine (TP=4 across these 4 pods, or 1Γ— TP=4 with 4 subprocess workers β€” TBD) +└─ 4Γ— NIXL agent (rc_mlx5) + +Optional elastic rollout: +└─ StatefulSet: prime-rl-mx-rollout-extra (launched mid-run at step 3) + +MX Server + Redis (kavin namespace, gRPC) +``` + +### 4.4 Benchmark scenarios + +Three scenarios run back-to-back on the same config so results are directly comparable. + +| # | Name | Config | What it measures | +|---|------|--------|------------------| +| A | SPG baseline | `rendezvous: spg`, `transfer_mode: direct`, `pipeline_replication: false` | PI's unmodified path on our 2-node shape. Establishes absolute baseline. | +| B | MX rendezvous | `rendezvous: mx_server`, `transfer_mode: direct`, `pipeline_replication: false` | Same transport, MX replaces SPG. Expected parity on data-path timing; adds dynamic discovery. | +| C | MX + pipeline + elastic | `rendezvous: mx_server`, `transfer_mode: direct`, `pipeline_replication: true`, launch 5th rollout at step 3 | Demonstrates two MX-only capabilities: (a) bandwidth amplification via rollout-as-source, (b) elastic mid-run join. | +| D | MX + scratch-buffer diagnostic | `rendezvous: mx_server`, `transfer_mode: scratch` | KL-drift triangulation β€” same RDMA, different target buffer. Useful if A/B/C uncovers a correctness issue. | +| E | MX + peer recovery | `rendezvous: mx_server`, `pipeline_replication: true`; `kubectl delete pod rollout-2` at step 5; pod restart triggers peer recovery | Demonstrates (a) recovering rank pulls from a surviving peer rather than the trainer, (b) recovery completes within one `update_weights` cycle, (c) trainer NIC bandwidth uninterrupted during recovery. | + +### 4.5 Metrics to capture + +_Populated after the run; target values are derived from PI's reported 12-node numbers and our verl POC parity runs._ + +Scenarios are numbered A-E. DAG observability (Β§4.5.5) applies wherever pipeline replication is enabled. + +#### 4.5.1 Weight-sync phase timing + +| Metric | Target (based on PI prod) | A SPG | B MX rendezvous | C MX + pipeline | D MX + scratch | +|--------|---------------------------|-------|-----------------|-----------------|----------------| +| Rendezvous wall-clock | ≀ 0.5 s first, ≀ 50 ms steady | TBD | TBD | TBD | TBD | +| Pre-write barrier | ~0.8 s (PI iter15) | TBD | TBD | TBD | TBD | +| Per-slot RDMA WRITE | parity with PI | TBD | TBD | TBD | TBD | +| Total `update_weights` | 1.0-1.5 s on our shape | TBD | TBD | TBD | TBD | + +#### 4.5.2 RDMA throughput + +| Metric | Target | A | B | C | D | +|--------|--------|---|---|---|---| +| Wire BW per trainer NIC | ~7.5 GB/s | TBD | TBD | TBD | TBD | +| Aggregate net BW | ~20 GB/s (4 NICs) | TBD | TBD | TBD | TBD | +| Aggregate with pipeline replication | > 20 GB/s effective | β€” | β€” | TBD | β€” | + +#### 4.5.3 MX Server round-trip latencies (B/C only) + +| Operation | Target | Measured | +|-----------|--------|----------| +| `publish_agent` (trainer) | < 5 ms | TBD | +| `poll_for_source` (rollout, warm) | < 10 ms | TBD | +| `poll_for_source` (rollout, cold) | < 50 ms | TBD | +| `unpublish` (blocking until drain) | < 20 ms after last pull ACK | TBD | +| `publish_rollout_source` (pipeline) | < 5 ms | TBD | + +#### 4.5.4 Elastic-join demo (C only) + +| Metric | Target | +|--------|--------| +| Time from extra rollout pod `Ready` to first weight received | < 2 Γ— (one training step) | +| Impact on other rollouts' weight-update time | None (MX Server load-balances, existing rollouts unaffected) | + +#### 4.5.5 DAG fan-out observability (C only) + +MX Server logs every `poll_for_source` response with the source-id it selected. Derived metrics: + +| Metric | Target | Measured | +|--------|--------|----------| +| First-rollout receive time (trainer-only source set) | matches A | TBD | +| Second-rollout receive time (post-first-rollout-publish) | ≀ first-rollout / 2 (log-fan-out) | TBD | +| Average sources-per-poll across all 5 rollouts in C | β‰₯ 2 (DAG engaged), ideally trending to R/2 as pulls complete | TBD | +| Trainer NIC utilization during the tail of receives | decreasing (DAG shifts load away from trainer) | TBD | +| Max concurrent pulls served by any single source | ≀ 3 (load-balancing effective) | TBD | + +These metrics validate the DAG pattern empirically. If sources-per-poll stays at 1 across all rollouts, pipeline replication isn't engaging β€” it's a regression indicator. + +#### 4.5.6 Peer recovery demo (E only) + +| Metric | Target | +|--------|--------| +| Source chosen for recovering rollout | A surviving rollout peer (not trainer) β€” verified via server access log | +| Time from recovered pod `Ready` to first weight received | ≀ 1 Γ— (one training step) | +| Trainer NIC bandwidth during recovery vs steady-state | within noise (< 5% dip) β€” DAG preference avoids loading trainer | +| Impact on other rollouts' weight-update time | None | + +#### 4.5.7 Training quality (B vs A vs D) + +KL divergence vs NCCL baseline, measured over 20 training steps per scenario: + +| Scenario | Target | Measured | +|----------|--------|----------| +| A SPG direct-refit | Matches PI observation (drifts past step ~7) | TBD β€” reference | +| B MX rendezvous direct-refit | **Expected: same drift as A** (data path unchanged) | TBD | +| D MX scratch | **Expected: bounded like NCCL** (if true, isolates bug to direct-refit target) | TBD β€” key diagnostic | + +This is the KL-drift triangulation data. If B drifts and D does not, the bug is in live-param-refit layout/identity, not NIXL. If both drift, the bug is deeper. Either result is valuable to the PI investigation. + +### 4.6 Results summary (to be filled) + +_To be written after the benchmark run. Expected headline numbers:_ + +- **Data-path parity**: MX rendezvous shows no wall-clock regression vs SPG on `update_weights` timing β€” same transport, same bytes. +- **Dynamic discovery**: MX rendezvous setup < 100 ms first call, < 20 ms steady-state. SPG equivalent requires process restart. +- **Pipeline replication**: aggregate effective bandwidth scales with rollout count beyond trainer NIC cap. +- **Elastic join**: a 5th rollout joins a 4-rollout setup mid-run and receives weights on the next push without affecting the other four. +- **Diagnostic value**: scratch-mode run provides the first bit-exact isolation of the PI KL drift to either transport or target layout. + +--- + +## 5. How to Run + +### 5.1 Prerequisites + +- GB200 GKE cluster, `kavin` namespace, `customer-gpu-w0e` + `customer-gpu-o7v` node pools available. +- MX Server running at `modelexpress-server.kavin.svc.cluster.local:8001` (from `k8s/deployments/modelexpress-server.yaml` in this repo). +- `tsh` auth for `nvcr.io/nvidian/dynamo-dev/` image registry. +- HuggingFace token with access to selected GLM model variant. + +### 5.2 Build the overlay image + +```bash +cd /path/to/prime-rl +git fetch origin nixl-weight-transfer +git checkout kavink/mx-on-nixl # our overlay branch +docker buildx build --platform linux/arm64 \ + -f docker/Dockerfile.mx-arm64 \ + -t nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:latest \ + --push . +``` + +Dockerfile layers PI's FP8/NIXL dependencies on top of our known-good GB200 runtime stack (UCX 1.20, NIXL main, PyTorch 2.6+cu128, vLLM 0.18.1, SDPA flash_attn shim for ARM64). + +### 5.3 Deploy + +```bash +kubectl apply -f k8s/prime-rl-mx-on-nixl/trainer.yaml +kubectl apply -f k8s/prime-rl-mx-on-nixl/rollout.yaml +# for scenario C only: +kubectl apply -f k8s/prime-rl-mx-on-nixl/rollout-extra.yaml # launch at step 3 +``` + +### 5.4 Run the benchmark matrix + +Orchestrated by `scripts/run-benchmark-matrix.sh` in the overlay branch: + +```bash +./scripts/run-benchmark-matrix.sh \ + --scenarios A,B,C,D \ + --model zai-org/GLM-4.5-Air-FP8 \ + --steps-per-scenario 20 \ + --output results/gb200-$(date +%Y%m%d-%H%M%S)/ +``` + +Output directory contains: per-scenario log files, `update_weights` timings CSV, MX Server access logs, KL-divergence traces (W&B offline dump), per-phase barrier timings, NIXL bandwidth reports from UCX. + +### 5.5 Generate the results tables + +```bash +python scripts/build-results-tables.py \ + --run-dir results/gb200-/ \ + --output docs/RL/PRIMERL_MX_OVERVIEW.md \ + --replace-section "## 4. Prototype on GB200 β€” Results" +``` + +Regenerates Β§4.5 and Β§4.6 of this document from the raw run data. + +--- + +## 6. Relationship to PR #2326 + +This design is an **overlay**, not a fork. The intended contribution shape: + +1. Adopt PR #2326 as the transport foundation β€” no reimplementation. +2. Publish a PR-on-PR against `PrimeIntellect-ai/prime-rl:nixl-weight-transfer` that adds: + - New helper: `src/prime_rl/utils/mx_rendezvous.py`. + - Env-var-gated dispatch switch in `NIXLWeightBroadcast.__init__` and `NIXLWeightUpdateWorker.init_nixl_transfer`: when `PRIME_RL_MX_RENDEZVOUS` is set, call `discover_spg_coordinator` to get SPG host/port from MX Server instead of the static `config.host`/`config.port`. + - Optional pipeline replication call in `NIXLWeightUpdateWorker.update_weights_from_path` receive tail (gated on `PRIME_RL_MX_PIPELINE_REPLICATION=1`). + - `k8s/` demo manifests + `run.sh` for scenarios A/B/C. + - `docker/Dockerfile.mx-on-nixl` layering UCX 1.19.x + NIXL 0.10.1 + MX client on top of PI's `Dockerfile.cuda`. + - `benchmarks/scripts/parse_mx_metrics.py` log aggregator. +3. SPG remains the default; opt-in only. No config surface changes in v0.1 β€” env-var-gated keeps the PR small. +4. When #2326 merges to `main`, retarget our PR base to `main` automatically. + +**Status**: Draft PR opened at **[PrimeIntellect-ai/prime-rl#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343)** targeting the `nixl-weight-transfer` branch. 3 commits, 11 files, +1508/-2. GB200 benchmark results pending image build completion. + +See `recovery/reinforcement learning/PRIME_INTELLECT_PR2326_Analysis.md` in the internal planning tree for the full Phase 0-3 plan and messaging guidance, and `OVERLAY_PR_EXECUTION_STATE.md` for live session state + restore instructions. From 861bac2087c2874d8b123dd85a4524338bdf3dc5 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Thu, 23 Apr 2026 13:06:19 -0700 Subject: [PATCH 15/25] docs(RL): refine VERL_MX_OVERVIEW for native nixl alignment and catalog value MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Rename Β§2 to frame MX as additive on verl's NIXL checkpoint engine - Document native nixl ring path positively; optional MX catalog + star - Add catalog benefits (balancing, multi-source, publish/retire, retention) - Fix RDMA READ source/destination wording; remove PRIME-RL references - Tone: prefer native nixl vs consider mx; align metrics cross-refs Made-with: Cursor Signed-off-by: Kavin Krishnan --- docs/RL/VERL_MX_OVERVIEW.md | 105 +++++++++++++++++++++++++++++------- 1 file changed, 87 insertions(+), 18 deletions(-) diff --git a/docs/RL/VERL_MX_OVERVIEW.md b/docs/RL/VERL_MX_OVERVIEW.md index d5ced9b2..31b2d230 100644 --- a/docs/RL/VERL_MX_OVERVIEW.md +++ b/docs/RL/VERL_MX_OVERVIEW.md @@ -3,13 +3,13 @@ **Last Updated**: April 2026 **Status**: E2E working β€” cross-node RDMA weight transfers via `MxCheckpointEngine` on 2Γ— GB200 nodes (GKE). -This document covers how ModelExpress (MX) plugs into [verl](https://github.com/volcengine/verl) for RL post-training weight synchronization. It walks through the component design, the Ray actor integration, the `CheckpointEngine` surface, and the GB200 prototype results. +This document covers how ModelExpress (MX) plugs into [verl](https://github.com/volcengine/verl) for RL post-training weight synchronization. It walks through the component design, **how MX relates to verl’s native `nixl` checkpoint engine**, the Ray actor integration, the `CheckpointEngine` surface, and the GB200 prototype results. --- ## 1. Design Overview -verl is a Ray-orchestrated RL framework. Its `CheckpointEngine` plugin system is the seam where MX slots in. `MxCheckpointEngine` replaces the default `naive` sync (process-local copy) or the built-in `nixl` ring engine with a **star topology over RDMA**, coordinated by the MX Server. +verl is a Ray-orchestrated RL framework. Its `CheckpointEngine` plugin system is the seam where MX slots in. Teams can use the default **`naive`** sync (process-local copy), verl’s native **`nixl`** engine (NIXL ring over RDMA), or the optional **`mx`** backend: same API and same NIXL data plane, with MX adding an **MX Server + Redis catalog** for discovery and a **star** trainerβ†’rollout wiring instead of a ring. ### What MX adds to verl @@ -83,13 +83,86 @@ graph TB ### Key ideas - **MX Server stores metadata only** β€” tensor names, GPU memory addresses, NIXL agent blobs, version numbers. It never touches weight bytes. -- **The heavy transfer is a one-sided RDMA READ** from the rollout's NIXL agent into the trainer's GPU memory, going GPU-direct over RoCE via `rc_mlx5`. -- **Star topology, not a ring.** verl's built-in NIXL engine uses a ring; MX uses the server as a central rendezvous, which is simpler to reason about and sets up future pipeline replication (rollouts can become secondary sources). -- **Bucketed transfer preserves shapes.** Unlike the PRIME-RL POC (which needs scratch buffers), verl's `CheckpointEngine` passes a tensor generator with names and shapes. MX packs them into GPU buckets and the receiver pulls them out by offset β€” no reshape tricks required. +- **The heavy transfer is a one-sided RDMA READ initiated on the rollout side**: each rollout NIXL agent **pulls** weight bytes **from** the trainer's registered GPU send bucket **into** its own local recv bucket (GPU-direct over RoCE via `rc_mlx5`). Logically weights still move **trainer β†’ rollout**; the NIC operation is a **read** whose *source* is trainer VRAM and *destination* is rollout VRAM. +- **Star vs ring wiring.** verl’s native **`nixl`** engine chains trainer and rollout ranks in a **ring** (each rank knows `prev` / `next`). **`mx`** keeps the same bucket + NIXL READ pattern but connects each rollout **directly** to the trainer, with the MX Server as a **rendezvous** for who to read fromβ€”useful when you want catalog-driven discovery or multiple future sources (e.g. rollouts that also publish). +- **Bucketed transfer preserves shapes.** verl's `CheckpointEngine` passes a tensor generator with names and shapes. MX packs them into GPU buckets and the receiver pulls them out by offset using per-bucket metadata (no separate layout side-channel beyond what the engine already carries). --- -## 2. Timing Diagram β€” One `update_weights` Step +## 2. What MX adds on top of verl’s native `nixl` checkpoint engine + +verl ships **`NIXLCheckpointEngine`** (`verl/checkpoint_engine/nixl_checkpoint_engine.py`, `backend: nixl`) as the **native GPU RDMA path** inside the same `CheckpointEngine` abstraction MX uses. It is mature, self-contained, and a strong default when a **single Ray job** wires trainer and rollout ranks with **driver-computed** `prev` / `next` links. + +**`MxCheckpointEngine` (`backend: mx`) does not replace that stack** β€” it **reuses** the same **bucket packing**, **ZMQ per-bucket metadata**, and **NIXL `initialize_xfer("READ", …)`** pull semantics. MX **adds** an **MX Server + Redis** catalog so consumers **discover** sources and versions via **gRPC**, and uses a **star** attach (each rollout READs from the trainer) instead of chaining through intermediate ranks. + +### What verl’s native `nixl` engine already provides + +| Aspect | Behavior | +|--------|----------| +| **Registration** | `@CheckpointEngineRegistry.register("nixl")`. Selected with `actor_rollout_ref.rollout.checkpoint_engine.backend=nixl`. | +| **Buffers** | Two byte buckets per rank (`send_buf`, `recv_buf`), registered with NIXL. On CUDA, verl often allocates via **CuPy** then views as `torch.uint8` to avoid registration issues with expandable PyTorch segments. | +| **Metadata between ranks** | **`NixlAgent`** wraps `nixl_agent` and uses **ZMQ** (`PULL` on each agent, `PUSH` to peers) to ship **per-bucket** `bucket_meta` (tensor name, shape, dtype, byte offset) plus a notify key β€” the same bucket-descriptor pattern **`mx`** uses peer-to-peer between trainer and rollouts once peers are connected. | +| **RDMA operation** | **`ReadOperation`**: `initialize_xfer("READ", local_descs, remote_descs, remote_agent, …)` then `transfer` / `check_xfer_state` until `DONE`. The initiator’s **local** buffer is the **destination**; **remote** is the **source** (trainer VRAM when reading from the trainer). | +| **Trainer send path** | Only **trainer rank 0** runs `send_weights`. Other trainer ranks consume the weight generator and no-op (same pattern **`mx`** inherits). Rank 0 fills a bucket and uses **`ReadableOperation`**: it tells the **next** agent in the ring that its send buffer is readable; the next agent performs the **RDMA READ** from rank 0. | +| **Rollout receive path** | Each rollout **`receive_weights`**: **READ from `prev_agent`**, slice tensors out of the bucket, **`yield`** to verl. If the rank has a **`next_agent`**, it also **`ReadableOperation`**s the same bucket metadata downstream so the next rank can READ β€” **pipeline along a chain**. | +| **Topology** | **`build_topology`** builds a **ring** of size **`rollout_world_size + 1`**: rank `0` = trainer head; ranks `1…N` = rollouts. Trainer has **next** only; last rollout has **prev** only; middle ranks have **prev** and **next**. | +| **Discovery / versioning** | The driver gathers each rank’s `NixlAgentMetadata` and installs **fixed `prev` / `next`** for that step β€” ideal when membership is stable and ordering is fully determined by Ray ranks. | +| **Standalone mode** | Same requirement as **`mx`** for non-`naive` engines: rollout uses **`CheckpointEngineWorker`** (disaggregated GPUs). | + +```mermaid +graph TB + subgraph ring["verl native NIXLCheckpointEngine β€” ring (conceptual)"] + T["Trainer rank 0
send_weights only"] + R1["Rollout 1
READ prev β†’ forward"] + R2["Rollout 2
READ prev β†’ forward"] + RN["Rollout N
READ prev only"] + T --> R1 + R1 --> R2 + R2 --> RN + end + + style T fill:#533483,stroke:#e94560,color:#fff + style R1 fill:#2e7d32,stroke:#66bb6a,color:#fff + style R2 fill:#2e7d32,stroke:#66bb6a,color:#fff + style RN fill:#2e7d32,stroke:#66bb6a,color:#fff +``` + +Weights still flow **trainer β†’ rollouts**; intermediate rollouts **repeat** the bucket so the last rank receives the full model without the trainer fanning out to everyone directly. + +### Optional MX layer: same NIXL moves, different control plane + +| Dimension | Native **`nixl`** (verl) | With **`mx`** | +|-----------|---------------------------|---------------| +| **Topology** | **Ring**: each bucket walks rank `0 β†’ 1 β†’ … β†’ N` with optional forward between rollouts | **Star**: each rollout **READs directly** from the trainer’s bucket (trainer completes **N** READ completions β€” fan-out on the trainer NIC) | +| **Who owns β€œwho do I read?”** | **Driver + Ray ranks**: `prev` / `next` from gathered `NixlAgentMetadata` | **MX Server + Redis** catalog: source identity, **version / step**, optional **worker_rank**, room for **multiple registered sources** | +| **Rendezvous across jobs / clusters** | Topology is **defined by this job’s rank graph** | Consumers **resolve** a source via **gRPC** to a stable catalog service, then attach with **NIXL** as today | +| **Extensibility** | New routing ideas are expressed in **verl driver / topology helpers** | Many policies can live in **catalog + client** without expanding core ring math for every deployment | + +#### What an MX catalog enables (additive) + +A Redis-backed MX catalog records **which processes publish which weight version** and the **NIXL metadata** needed to attach. That is **optional**; when you need it, it supports: + +- **Global view for load spreading** β€” Readers can ask the catalog **which source should serve this read** when **several replicas** expose the same version, instead of always following one fixed ring order. + +- **Load- and locality-aware steering** β€” Entries are **per-source identities**, so policy can prefer **less busy** or **network-closer** holders as the fleet grows, without each node probing the entire cluster. + +- **More than one read source over time** β€” Processes that have **finished** receiving a version can register as **additional sources**; the catalog can list **multiple holders** so bandwidth can spread through the pool (trainer plus peers). + +- **Publish / retire coordination** β€” A central record can track **in-flight reads** vs **writer reuse** of GPU pages so trainers retire a version before mutating shared buffersβ€”coordination that is harder when each rank only knows ring neighbors. + +- **Retention with elastic membership** β€” A **cluster-wide** picture of which versions exist and where a **last readable copy** remains before unregister helps elastic scale-out / scale-in. + +- **Stable service names** β€” Trainer and rollout attach to **DNS-backed catalog endpoints** even when Ray placement or node pools change between runs. + +**Prefer native `nixl`** when a single Ray job, ring latency, and driver-wired `prev` / `next` are exactly what you want β€” nothing else required. + +**Consider `mx`** when you want **catalog-driven discovery**, **explicit versions**, **direct trainerβ†’each-rollout** reads, or a path to **multi-source** and **policy-driven** routing **without** growing custom ring topology code in-tree for every site. + +Both backends use the same **`CheckpointEngine`** actor model and keep **bulk weight bytes off Ray**; both use **NIXL/RoCE** for the actual GPU transfers. + +--- + +## 3. Timing Diagram β€” One `update_weights` Step ```mermaid sequenceDiagram @@ -154,7 +227,7 @@ For the same model and cluster, the default `naive` engine averages **~1.6s** (i --- -## 3. ModelExpress and the Ray Actor Design +## 4. ModelExpress and the Ray Actor Design verl's runtime is a web of Ray actors. Understanding where `MxCheckpointEngine` lives inside that web is the key to the integration. @@ -269,7 +342,7 @@ One class, two roles, distinguished by which actor type instantiates it. Both si --- -## 4. ModelExpress and the Checkpoint Engine +## 5. ModelExpress and the Checkpoint Engine verl's `CheckpointEngine` ABC is the small, well-defined plugin surface that makes MX a drop-in. @@ -328,7 +401,7 @@ graph LR | Method | Trainer side | Rollout side | |--------|--------------|--------------| -| `prepare()` | Allocate pinned GPU send bucket, register with NIXL agent, publish agent metadata + tensor layout to MX Server via gRPC | Allocate GPU recv bucket, register with NIXL, call `MxClient.poll_for_source(model_name)` to get the trainer's agent blob | +| `prepare()` | Allocate registered GPU send bucket, register with NIXL agent, publish agent metadata + tensor layout to MX Server via gRPC | Allocate GPU recv bucket, register with NIXL, call `MxClient.poll_for_source(model_name)` to get the trainer's agent blob | | `build_topology()` | Driver-side utility. Produces `(trainer_agent β†’ [rollout_agents])` star mapping | Same (called on driver) | | `init_process_group()` | For each rollout rank: `nixl_agent.add_remote_agent(rollout_meta)` | `nixl_agent.add_remote_agent(trainer_meta)`; ZMQ PULL socket bound on free port, advertised to trainer | | `send_weights(iter)` | Consume `(name, tensor)` generator. Pack tensors into the GPU bucket at known offsets. Send `BucketDesc{name, shape, dtype, offset, nbytes}` over ZMQ PUSH. Block until rollout signals ACK | β€” | @@ -355,18 +428,13 @@ actor_rollout_ref: skip_sleep_wake: true # avoid vLLM multiproc sleep/wake crash on ARM64 ``` -### What differs from PRIME-RL +### How verl uses `MxCheckpointEngine` on the tensor path -| Concern | PRIME-RL | verl / MxCheckpointEngine | -|---------|----------|---------------------------| -| Plugin point | Custom `WeightBroadcast` ABC + vLLM worker extension | `CheckpointEngine` ABC (native, already has NIXL and NCCL siblings) | -| Shape handling | Scratch buffers + safetensors header reshape | Bucket carries `(name, shape, dtype)` β€” no reshape needed | -| Fused params (Q/K/V, gate/up) | Rely on `model.load_weights()` to fuse from HF names | Trainer publishes already-in-target-format buckets; rollout passes through to `load_weights` | -| Allgather on trainer | Rank 0 gathers FSDP shards before publish | FSDP shards are packed per-rank; star topology fans out to rollout ranks | +verl streams `(name, tensor)` pairs through the `CheckpointEngine` API. `MxCheckpointEngine` packs those into registered GPU buckets; per-bucket metadata carries **name, shape, dtype, and byte offset** so `receive_weights` can slice views and hand them to **`ServerAdapter.load_weights`** without an extra layout file. That is the same bucket pattern as verl’s native `nixl` engine; **`mx`** swaps ring wiring for **catalog + star** discovery as described in Β§2. --- -## 5. Prototype on GB200 β€” Results +## 6. Prototype on GB200 β€” Results ### Cluster @@ -434,6 +502,7 @@ MX Server + Redis (kavin namespace, reachable from both nodes over gRPC) | Avg step time | ~8.1-8.8s | | Avg `update_weights` (MX) | **~1.25s** | | Avg `update_weights` (naive baseline, hybrid mode) | ~1.6s | +| verl native `nixl` (same step, standalone) | *Not benchmarked in this doc* β€” same NIXL READ + buckets as **`mx`**, with **ring** topology and **no** MX Server | | Throughput | 135-163 tokens/sec | | Transport | NIXL / UCX `rc_mlx5` (RoCE RDMA) | | Data path | Cross-node GPUβ†’GPU (no CPU staging, no filesystem) | @@ -470,6 +539,6 @@ This proved publisher/receiver correctness before moving to cross-node. 1. **MX works on Ray.** The `CheckpointEngine` plugin surface is sufficient to express a star-topology RDMA transfer with server-mediated discovery. 2. **Cross-node RoCE RDMA is real.** `update_weights` at 1.25s for a 3 GB model on a 2-node GB200 cluster is consistent with UCX `rc_mlx5` over RoCE and beats the in-process naive baseline even before we tune the bucket size. -3. **The ARM64 path is painful but survivable.** All the image work is in `docker/Dockerfile.mx-arm64` and the compat shim, and is shared with (and borrowed from) the PRIME-RL POC. +3. **The ARM64 path is painful but survivable.** All the image work is in `docker/Dockerfile.mx-arm64` and the compat shim, aligned with other GB200 MX + vLLM container iterations on the same stack. 4. **Standalone rollout is the production shape.** Hybrid/colocated mode is useful for debugging but cannot drive any non-naive checkpoint engine β€” true for NIXL, NCCL, and MX alike. From 6ccec5d79ae1c7a9efd41ad29870d08cf920169c Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Thu, 23 Apr 2026 21:15:15 -0700 Subject: [PATCH 16/25] docs(RL): add Path B native MX design as alternative to PI overlay Companion to PRIMERL_MX_OVERVIEW.md (Path A). Documents an MX-shaped weight broadcast design for prime-rl that uses PI's NIXL transport as the data plane but exposes ModelExpress's traditional API surface (model-agnostic, server-mediated, scratch-buffer-default, cross-framework-portable) instead of PI's per-model conversion_specs + slot system. Key positioning: - Path A = strict overlay on PI's API. Smallest diff. In flight. - Path B = native MX shape. Larger diff. Staged design only. - Both ship as discriminator options on weight_broadcast.type (existing nixl + new mx coexist). Documents: - Why we'd consider B (per-model spec wall, KL drift inheritance, elastic shapes, cross-framework alignment). - File footprint (~600 LOC new, ~70 LOC modified). - Migration path (toml-only for users). - Pitch sequence for PI conversation. - Inflection points to pivot from A to B. Made-with: Cursor Signed-off-by: Kavin Krishnan --- docs/RL/PRIMERL_MX_NATIVE_DESIGN.md | 276 ++++++++++++++++++++++++++++ 1 file changed, 276 insertions(+) create mode 100644 docs/RL/PRIMERL_MX_NATIVE_DESIGN.md diff --git a/docs/RL/PRIMERL_MX_NATIVE_DESIGN.md b/docs/RL/PRIMERL_MX_NATIVE_DESIGN.md new file mode 100644 index 00000000..be19e393 --- /dev/null +++ b/docs/RL/PRIMERL_MX_NATIVE_DESIGN.md @@ -0,0 +1,276 @@ +# PRIME-RL Γ— ModelExpress β€” Native API Design (Path B) + +**Status**: Design proposal (no code yet) +**Last Updated**: April 2026 +**Companion to**: `PRIMERL_MX_OVERVIEW.md` (the overlay-on-PI design, "Path A") + +This document describes a **native MX-shaped weight broadcast backend** for PRIME-RL that uses [PI's NIXL transport](https://github.com/PrimeIntellect-ai/prime-rl/pull/2326) as the bytes-on-wire data plane, but exposes **ModelExpress's traditional API surface** to PRIME-RL β€” model-agnostic, server-mediated, scratch-buffer-default, cross-framework. + +It's the alternative to "Path A" (strict overlay on PI's API). Path A is in flight as draft PR [PrimeIntellect-ai/prime-rl#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343); Path B is staged here as the design we'd advocate for if Path A hits friction or if the team wants the broader strategic value. + +--- + +## 1. Why This Doc Exists + +Path A's overlay strategy preserves PI's NIXL API surface end-to-end and replaces only the SPG rendezvous with MX Server discovery. It's small, cooperative, easy to merge β€” and it inherits every PI design constraint: + +- **Per-model `conversion_specs()` requirement.** PI's `TransportPlan` only supports models that PI has authored a spec table for: today, just `glm_moe_dsa`. Plain Qwen3, Llama, Mixtral, anything else needs a spec table written. We just discovered this when scenario A's trainer crashed with `'FSDPQwen3ForCausalLM' object has no attribute 'conversion_specs'` and had to monkey-patch HF Qwen3 to unblock. +- **Direct refit only β€” KL drift class of bugs is in scope.** PI's PR is currently blocked by 27+ iterations of KL drift investigation. Their byte-exact transport is correct; the drift comes from concurrent UCX writes into live vLLM `param.data`. Inheriting their target-buffer model means inheriting that bug surface. +- **Static, startup-time tensor registration.** `Slot{Sharded,Gathered,Expert}` assume fixed tensor shapes registered with NIXL once at init. Elastic workloads (LoRA-RL with dynamic adapter add/remove, frozen-policy-adapts variants, growing context-len rollouts) don't fit this model cleanly. +- **prime-rl-only.** PI's `NIXLWeightBroadcast` lives in `prime_rl/trainer/rl/broadcast/nixl.py`; cross-framework portability would require us to copy the design into verl + future NeMo-RL. + +Path B is the "what if we redesigned the prime-rl-side weight broadcast around MX's shape, but kept PI's UCX/NIXL setup as-is for the bytes" answer. The data plane is identical (same RDMA bytes on wire, same per-NIC bandwidth, same `rc_mlx5` transport, same per-rank NIC pin). Only the control plane and tensor ABI change. + +--- + +## 2. What MX-Traditional Looks Like (Recap) + +ModelExpress has a consistent API across the integrations we've shipped (verl `MxCheckpointEngine`, our existing PRIME-RL `ModelExpressWeightBroadcast` on `kavink/mx-weight-broadcast`): + +```python +# Trainer side +publisher = MxTrainingPublisher( + agent_name="trainer-rank-0", + device_id=local_rank, + server_url="modelexpress-server.kavin.svc.cluster.local:8001", +) +publisher.initialize() # NIXL agent up, gRPC connected +publisher.publish_weights(state_dict, step=N) # per training step +# ...later, before optimizer.step() reuses buffers: +publisher.unpublish(version=N) # mutability contract + + +# Inference side +receiver = MxRefitReceiver( + model_name="...", + worker_rank=k, + device_id=local_rank, +) +receiver.initialize(model_tensors=scratch_or_live_dict) # NIXL register receive buffers +source = receiver.poll_for_source(min_version=N) # discover trainer +receiver.receive_weights(source) # RDMA pull (or accept WRITE) +# scratch path: vllm_model.load_weights(scratch_iter) +# direct path: receive lands directly in vllm_model.param.data +``` + +Salient differences from PI's design: + +| Concern | PI (Path A overlay inherits) | MX-traditional (Path B exposes) | +|---------|------------------------------|--------------------------------| +| Discovery | SPG static rendezvous | gRPC `MxClient` (publish, list_sources, get_metadata) | +| Source identity | Implicit rank pairing | Content-addressed `mx_source_id = sha256(SourceIdentity)` | +| Per-model contract | `model.conversion_specs(layer_idx)` | None β€” model-agnostic; publishes live `state_dict` tuples | +| Tensor ABI | Fixed `Slot{...}` registered at startup | Per-step `(name, shape, dtype, gpu_addr)` published; receiver re-registers if shape drifts | +| Receive target | Direct WRITE into live `param.data` | **Scratch buffer + `model.load_weights()`** by default; opt-in direct refit | +| Quantization | First-class `ConversionSpec`/`QuantizationSpec` | Trainer pre-quantizes if needed; MX is dtype-agnostic | +| Lifecycle | rendezvous β†’ register Γ— N β†’ write Γ— N per startup | publish/poll/receive/unpublish per step; no static startup registration | +| Versioning | Implicit step counter | First-class `extra_parameters.training_step` | +| Mutability contract | Implicit (root cause of KL drift?) | Explicit `unpublish()` with drain | +| Cross-framework | prime-rl only | Same client runs in verl, prime-rl, future NeMo-RL | +| vLLM integration | Worker extension only | Worker extension OR `WeightTransferEngine` plugin (Step 11) | + +--- + +## 3. Path B Architecture + +Same NIXL data plane as PI; different prime-rl-side abstractions on top of it. + +``` +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ prime-rl trainer process β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ β”‚ MxWeightBroadcast (new, in prime_rl/trainer/rl/broadcast/) β”‚ β”‚ +β”‚ β”‚ ─ implements PI's WeightBroadcast ABC β”‚ β”‚ +β”‚ β”‚ ─ delegates the data path to MxTrainingPublisher β”‚ β”‚ +β”‚ β”‚ ─ delegates the control path to MxClient gRPC β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ +β”‚ β”‚ β”‚ β”‚ +β”‚ β–Ό data β–Ό metadata β”‚ +β”‚ MxTrainingPublisher MxClient(server_url=...) β”‚ +β”‚ (modelexpress) (modelexpress) β”‚ +β”‚ β”‚ β”‚ β”‚ +β”‚ β–Ό β–Ό β”‚ +β”‚ NixlAgentWrapper β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ (PI's existing class) β”‚ MX Server (gRPC + Redis) β”‚ β”‚ +β”‚ β”‚ β”‚ ─ source registry β”‚ β”‚ +β”‚ β–Ό post_write β”‚ ─ poll_for_source β”‚ β”‚ +β”‚ UCX rc_mlx5 RDMA β”‚ ─ pipeline replication β”‚ β”‚ +β”‚ (same wire as PI) β”‚ ─ retention + versioning β”‚ β”‚ +β”‚ β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ +β”‚ β–Ό β”‚ +β”‚ ConnectX-7 NIC ─── RoCE ───► rollout NIC β”‚ +β”‚ β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ + +β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” +β”‚ prime-rl inference (vLLM) worker β”‚ +β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ +β”‚ β”‚ MxWeightUpdateWorker (new) β”‚ β”‚ +β”‚ β”‚ ─ vLLM worker extension β”‚ β”‚ +β”‚ β”‚ ─ MxRefitReceiver delegate β”‚ β”‚ +β”‚ β”‚ ─ default: scratch buffer + model.load_weights() β”‚ β”‚ +β”‚ β”‚ ─ opt-in: direct refit into live param.data β”‚ β”‚ +β”‚ β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ β”‚ +β”‚ β”‚ β”‚ +β”‚ β–Ό β”‚ +β”‚ MxRefitReceiver (modelexpress) β”‚ +β”‚ β”‚ β”‚ +β”‚ β–Ό β”‚ +β”‚ NixlAgentWrapper (PI's existing class) β”‚ +β”‚ β”‚ β”‚ +β”‚ β–Ό β”‚ +β”‚ accepts WRITE from trainer (same as PI) β”‚ +β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜ +``` + +### What we adopt unchanged from PI's PR + +These are real engineering wins; we'd be foolish to redo them: + +- `NixlAgentWrapper` (UCX agent setup, register_tensor, prep_local/prep_remote, post_write, wait, drain) +- `pin_ucx_rail` (per-rank NIC pinning that's the difference between 4.8 GB/s and 7.5 GB/s) +- `classic_cuda_pool` (allocator workaround for `expandable_segments` + `ibv_reg_mr` "local protection" bug) +- The runtime image stack (UCX 1.19, NIXL 0.10.1, ARM64 quirks) +- Pre-write SPG barrier / quiescence pattern (the iter15 fix; we'd implement equivalent via MX Server fence RPC) +- HSDP primary-replica gate (only `dp_replicate == 0` runs the protocol) + +These come into MX as either direct re-export or via a `prime_rl.utils` import; no need to fork. + +### What we replace with MX-shape + +| PI component | Our replacement | +|--------------|-----------------| +| `NIXLWeightBroadcast` | `MxWeightBroadcast` β€” same prime-rl ABC (`WeightBroadcast`), different internals | +| `NIXLWeightUpdateWorker` | `MxWeightUpdateWorker` β€” same vLLM worker_extension_cls slot | +| `TransportPlan` | Replaced. Per-step iteration over `state_dict` tuples instead of slot table | +| `model.conversion_specs(layer_idx)` | Removed entirely. Trainer publishes whatever's in `state_dict`; vLLM's `model.load_weights()` does the HFβ†’kernel format on inference (its tested code path) | +| `Slot{Sharded,Gathered,Expert}` | Removed. The per-rank publishing semantics come from `MxTrainingPublisher`'s `worker_rank` field; tiny-tensor coalescing is a NIXL register optimization not a slot type | +| SPG rendezvous (`StatelessProcessGroup`) | Removed. Discovery is `MxClient.poll_for_source`; per-step barrier is `MxClient.fence` (new) or fall back to a small SPG over MX-discovered endpoints | +| `ConversionSpec`/`QuantizationSpec` (FP8 quantize on trainer) | Optional on trainer side. Adopted as `MxQuantizer` if user wants FP8 fast path; default is BF16 passthrough. Inference-side decompresses via vLLM's existing FP8 loader, not via NIXL post-processing. | + +### What we add that PI doesn't have + +These are MX traditional features that translate naturally into prime-rl now that we're not constrained by PI's slot abstraction: + +- **Pipeline replication** (TensorHub-style DAG): rollout publishes itself as a secondary source after receive; MX Server load-balances new pollers across trainer + rollouts. +- **Peer recovery**: a restarting rollout pod pulls from any surviving peer (via `poll_for_sources` ranked by health/locality), not always from trainer. +- **Versioning + retention**: keep-latest-N versions on MX Server; rollouts can request a specific version or "latest stable." +- **Mutability contract**: explicit `unpublish(version)` before trainer reuses slot buffers; server blocks until in-flight pulls drain. Directly addresses PI's KL drift hypothesis (write ordering / live param visibility). +- **Cross-framework**: same `MxTrainingPublisher`/`MxRefitReceiver` already proven in verl and our existing PRIME-RL POC. Path B is the alignment: prime-rl + verl + future NeMo-RL share the same MX abstractions. +- **Scratch-buffer default**: receive lands in isolated GPU tensors; vLLM `model.load_weights()` applies them via its tested NCCL-equivalent path. KL-drift class of bugs falls away. Direct refit becomes opt-in for users who measure and accept the correctness risk. +- **Elastic shapes**: per-step `(name, shape, dtype)` publishing means LoRA-RL, dynamic adapters, growing context lengths all work without re-init. + +--- + +## 4. Concrete File Footprint + +What changes in `KavinKrishnan/prime-rl:kavink/mx-on-nixl` to ship Path B (relative to current Path A overlay): + +### New files + +| File | Purpose | Est. LOC | +|------|---------|----------| +| `src/prime_rl/trainer/rl/broadcast/mx.py` | `MxWeightBroadcast` β€” new prime-rl `WeightBroadcast` impl. ~3 methods: `__init__` (publisher + agent setup), `broadcast_weights(model, step)` (publish state_dict tuples, signal orchestrator), `shutdown()` | ~300 | +| `src/prime_rl/inference/vllm/worker/mx.py` | `MxWeightUpdateWorker` β€” vLLM worker_extension_cls. Delegates to `MxRefitReceiver` for receive, then `model.load_weights(scratch_iter)` for apply. Direct-refit opt-in via env. | ~250 | +| `src/prime_rl/configs/...` | New `MxWeightBroadcastConfig` discriminator on `WeightBroadcastConfig` union. `type: "mx"` | ~60 | +| `docs/weight-transfer-modelexpress.md` | Already requested by `@mikasenghaas` in #2326 review. Documents all four backends (filesystem, nccl, nixl, mx) with selection guidance | ~250 | + +### Modified files + +| File | Change | Est. LOC | +|------|--------|----------| +| `src/prime_rl/configs/trainer.py`, `orchestrator.py`, `rl.py` | Add `MxWeightBroadcastConfig` to discriminated union; thread through unify-mode for the `rl` entrypoint | +40 | +| `src/prime_rl/trainer/rl/broadcast/__init__.py` | Add `mx` dispatch in `setup_weight_broadcast()` | +15 | +| `src/prime_rl/inference/vllm/server.py` (or where worker_extension_cls is wired) | `mx` value selects `MxWeightUpdateWorker` | +10 | +| `src/prime_rl/utils/client.py` (TRANSFER_READY marker) | Touch for `mx` backend too (already protocol-agnostic per the existing review feedback) | +5 | + +### Files we don't touch + +PI's NIXL backend stays in place. Users who want PI's path get `type: "nixl"`; users who want ours get `type: "mx"`. No conflict between the two backends β€” they coexist as discriminator options. + +### What MX-side code we add (in `ai-dynamo/modelexpress`) + +Mostly already exists from the verl + existing PRIME-RL POCs. Net new: + +- `MxTrainingPublisher.publish_weights_via_nixl(state_dict, step, agent)` β€” adapt our existing publisher to use PI's `NixlAgentWrapper` directly (same NIXL bytes as PI, just driven from our publisher class). +- `MxRefitReceiver.receive_weights_via_nixl(source, agent, target)` β€” same. +- `MxClient.fence(model, version, world_size)` β€” server-mediated barrier (replaces SPG barrier per step). Optional; can fall back to SPG over MX-discovered endpoints if we want to minimize server changes. + +Plus the server-side capabilities tracked in `PRIMERL_MX_OVERVIEW.md` Β§3 (pipeline replication index, retention, peer recovery preference ordering). + +--- + +## 5. Migration Path for Users + +Users currently on PI's `type: "nixl"`: + +```toml +# Before (PI's path) +[weight_broadcast] +type = "nixl" +host = "..." +port = 29502 +inference_world_size = N +backends = ["UCX"] + +# After (MX-native path), no rebuild required +[weight_broadcast] +type = "mx" +mx_server_url = "modelexpress-server.kavin.svc.cluster.local:8001" +model_name = "..." +inference_world_size = N +# transfer_mode = "scratch" # default; opt-in "direct" for the speed/correctness tradeoff +# pipeline_replication = false # default; opt-in for fan-out scale +``` + +Same image, same UCX setup, same NIXL transport β€” just a different config discriminator and ~600 LOC of new code in prime-rl. PI's existing `nixl` backend stays available. + +For ourselves: we delete the monkey-patch on Qwen3 (`qwen3_specs_patch.py`) β€” it's no longer needed because Path B doesn't require per-model conversion specs. Path A's overlay survives as-is for users who want to stick with PI's transport API exactly. + +--- + +## 6. Pitch Sequence for the Design Conversation + +If we propose Path B to PI on the existing draft PR (or in a new sibling PR): + +1. **Acknowledge PI's transport win directly.** "Your NIXL transport is excellent β€” UCX setup, classic_cuda_pool, pin_ucx_rail, FP8 ConversionSpec are all correct. Path B keeps every byte of that." + +2. **Frame the divergence as scope.** "We've been running ModelExpress as a metadata + elasticity layer across verl and our internal PRIME-RL POC for several months. The MX-shape API is model-agnostic, server-mediated, scratch-buffer-default β€” it solves problems your `Slot`/`ConversionSpec` design doesn't address (cross-framework, elastic shapes, KL drift via scratch path, retention, pipeline replication)." + +3. **Show the demo evidence.** "Path A overlay (already up as #2343) proved the metadata layer works on top of your transport. Path B extends that with a native MX API β€” here's the design doc, here's a draft diff." + +4. **Make the cohabitation explicit.** "Path B doesn't replace `type: nixl`. It adds `type: mx` as a sibling discriminator. Users pick. PI's GLM-5 production keeps using `nixl`; users who want a model-agnostic / cross-framework / scratch-default path use `mx`. Same UCX runtime image, same NIXL bytes, same `pin_ucx_rail` discipline β€” different control plane." + +5. **Address the KL drift directly.** "The drift you're chasing in iter22-27 is consistent with concurrent live-param writes. The MX scratch-buffer path isn't subject to that bug surface because writes land in isolated tensors, then `model.load_weights()` applies them via vLLM's NCCL-equivalent code path. Happy to A/B this on your iter26/27 config β€” same NIXL bytes, just a different target buffer. If your drift disappears in scratch mode, that's diagnostic data even if you keep `nixl` as default." + +--- + +## 7. Inflection Points to Pivot A β†’ B + +- **PI's PR stalls on KL drift > 2 weeks** without a root-cause fix landing. Path B's scratch-buffer default ships independent of that investigation. +- **Reviewer pushback on Path A's overlay shape** β€” env-var-as-config or monkey-patching Qwen3 specs gets pushback that suggests a cleaner design is wanted. +- **Need for elastic / LoRA-RL features** that PI's slot system can't accommodate. Path B's per-step publish handles dynamic shapes naturally. +- **Cross-framework alignment becomes a strategic priority** β€” leadership wants "one weight broadcast story across prime-rl + verl + future frameworks" rather than per-framework redesigns. +- **Scenario A's first NIXL run reveals the same KL drift on Qwen3** as PI saw on GLM-5. That's a strong empirical signal that direct-refit-into-live-params is the bug surface, not GLM-specific. Path B's scratch-default becomes the obvious fix. + +--- + +## 8. Decisions Pending + +| Decision | Default if not made | +|----------|---------------------| +| Ship Path B alongside Path A or as a follow-up PR? | Follow-up PR; Path A goes first | +| MX-mediated barrier (`MxClient.fence`) or SPG-over-MX-discovered endpoints? | SPG-over-MX-discovered for v0.1 (smaller server change) | +| `MxWeightBroadcast` lives in `prime_rl/trainer/rl/broadcast/mx.py` (in-tree) or as a plugin from `modelexpress` package? | In-tree mirroring PI's nixl.py pattern | +| Scratch vs direct refit default | Scratch (correctness-safe). Direct opt-in via `transfer_mode: direct` | +| Pipeline replication default | Off. Opt-in via `pipeline_replication: true` | + +--- + +## 9. Status + +- Path A draft PR: [PrimeIntellect-ai/prime-rl#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343) β€” Qwen3 conversion_spec patch in flight, image rebuild done, deploy pending tsh refresh. +- Path B: design only (this doc). No code yet. Ready to author when an inflection point above triggers. +- Tracking: `recovery/reinforcement learning/PRIME_INTELLECT_PR2326_Analysis.md` for the strategic comparison; `PRIMERL_MX_OVERVIEW.md` for the Path A overlay design. + +This doc is the artifact we'd reference in the conversation with PI if Path A doesn't carry the day on its own. From 0e10aba401e76f16a7f0b3e745379e0810e93d34 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Fri, 24 Apr 2026 09:24:42 -0700 Subject: [PATCH 17/25] docs(RL): clarify v0.1 overlay scope in PRIMERL_MX diagrams MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Fix two diagrams in PRIMERL_MX_OVERVIEW.md that implied MX replaces PI #2326's entire control plane, when v0.1 is env-var-gated and only swaps SPG coordinator discovery. Component diagram: - Label every imported PI element "(PI, unchanged)" and every MX element "(MX overlay, env-var gated)" so reader can see the split. - Re-draw control-plane edges as dotted green (MX additions): publish_spg_coordinator (boot), mark_version_ready (per step), discover_spg_coordinator (boot), publish_rollout_source (optional). - Add the SPG 2-round all_gather_obj edge between trainer and rollout labeled as PI code β€” previously missing entirely, so readers could think MX alone wired up NIXL agents. - Data-plane edge labeled "trainer -> rollout recv buffer" to match PI's actual WRITE semantics instead of a generic bidirectional arrow. Timing diagram: - Wrap flow in three tinted bands: green boot-time MX discovery (the only v0.1 change) vs purple per-step SPG metadata rounds + RDMA WRITE (PI, unchanged). - Move discovery out of the per-step par block into a "once per run" boot-time block β€” register_coordinator / discover_spg_coordinator are init-time, not update-time ops. - Add the SPG 2-round all_gather_obj step that was missing: round 1 exchanges agent_meta + slot_layout + recv-buffer descriptors, round 2 exchanges per-slot xfer descriptors. - Relabel the RDMA step as "NIXL RDMA WRITE -> rollout's recv buffer" so the write direction is explicit. - Trigger for unpublish changed from vague "next iteration" to "async_level >= 1" β€” ties the mutability contract to the actual concurrency regime that needs it. - Add explicit legend calling out MX-added vs PI-unchanged. No design changes; docs-only clarification so the diagrams match what the overlay code actually does. Made-with: Cursor Signed-off-by: Kavin Krishnan --- docs/RL/PRIMERL_MX_OVERVIEW.md | 103 +++++++++++++++++++-------------- 1 file changed, 61 insertions(+), 42 deletions(-) diff --git a/docs/RL/PRIMERL_MX_OVERVIEW.md b/docs/RL/PRIMERL_MX_OVERVIEW.md index d381a5c8..63510bc1 100644 --- a/docs/RL/PRIMERL_MX_OVERVIEW.md +++ b/docs/RL/PRIMERL_MX_OVERVIEW.md @@ -34,12 +34,12 @@ PR #2326 gives PRIME-RL a bit-exact RDMA weight transport built on NIXL/UCX over ```mermaid graph TB subgraph driver["Driver Β· CPU Β· orchestrator process"] - orch["RL Orchestrator
(existing)"] + orch["RL Orchestrator
(PI, unchanged)"] httpapi["/pause /resume /update_weights
(vLLM WeightTransferEngine endpoints)"] orch --> httpapi end - subgraph mx_meta["Metadata Plane Β· CPU"] + subgraph mx_meta["Metadata Plane Β· CPU Β· MX overlay adds"] mx["MX Server
(gRPC)"] redis[("Redis")] mx --> redis @@ -48,30 +48,31 @@ graph TB subgraph trainer["Trainer node Β· FSDP2 + optimizer"] direction TB tw["Trainer ranks Γ— N
(dp_shard Γ— cp)"] - tp["NIXLWeightBroadcast
+ TransportPlan
(PI's code, unchanged)"] - pub["MxTrainingPublisher
(MX overlay)"] - tnixl(["NIXL Agent Γ— N"]) + tp["NIXLWeightBroadcast
+ TransportPlan
(PI, unchanged)"] + pub["MxTrainingPublisher
(MX overlay,
env-var gated)"] + tnixl(["NIXL Agent Γ— N
(PI)"]) tw --> tp - tp -->|slot registry| pub + tp -.->|register_coordinator
once at boot| pub tp --> tnixl end subgraph rollout["Rollout nodes Β· vLLM TP"] direction TB - cew["NIXLWeightUpdateWorker Γ— M
(PI's code, unchanged)"] - rcv["MxRefitReceiver
(MX overlay)"] - rnixl(["NIXL Agent Γ— M"]) + cew["NIXLWeightUpdateWorker Γ— M
(PI, unchanged)"] + rcv["MxRefitReceiver
(MX overlay,
env-var gated)"] + rnixl(["NIXL Agent Γ— M
(PI)"]) vllm["vLLM engine Γ— M
(live params)"] cew --> rcv cew --> rnixl - cew -. "in-place RDMA WRITE
or scratch-buffer stage" .-> vllm + cew ==> vllm end - pub -- "gRPC publish_agent
slots + agent_meta + version" --> mx - rcv -- "gRPC poll_for_source
(model_name, worker_rank)" --> mx - mx -- "trainer agent_meta
slot layout + NIXL blob" --> rcv - tnixl <== "NIXL RDMA WRITE
RoCE Β· rc_mlx5
(PI transport, unchanged)" ==> rnixl - rcv -. "publish_rollout_source
(pipeline replication)" .-> mx + pub -. "gRPC publish_spg_coordinator
(boot) Β· mark_version_ready (per step)" .-> mx + rcv -. "gRPC discover_spg_coordinator
(boot)" .-> mx + rcv -. "gRPC publish_rollout_source
(Β§3.2 pipeline replication, optional)" .-> mx + + cew <== "SPG all_gather_obj Γ— 2 rounds
(agent_meta, slot_layout, xfer descs)
PI code, unchanged" ==> tp + tnixl <== "NIXL RDMA WRITE
trainer β†’ rollout recv buffer
RoCE Β· rc_mlx5 (PI, unchanged)" ==> rnixl style driver fill:#1a1a2e,stroke:#533483,color:#e0e0e0 style mx_meta fill:#1a1a2e,stroke:#4caf50,color:#e0e0e0 @@ -91,7 +92,12 @@ graph TB style redis fill:#162447,stroke:#533483,color:#e0e0e0 ``` -**Legend**: Green boxes = MX/NIXL additions (metadata plane + overlay client classes). Purple = existing PRIME-RL / vLLM / PI-PR-#2326 components. The trainer-to-rollout NIXL arrow is the exact same RDMA WRITE path PI introduced; MX does not touch the data plane. +**Legend**: + +- **Purple boxes + solid purple edges** β€” existing PRIME-RL / vLLM / PI #2326 code the overlay imports and uses as-is. +- **Green boxes + dotted green edges** β€” MX additions: MX Server, overlay client classes, gRPC control-plane calls. +- **Data plane** (trainer NIXL ↔ rollout NIXL, solid double-edge) is **100% PI** β€” MX does not see or touch weight bytes. +- **SPG metadata rounds** (trainer ↔ rollout double-edge between `TransportPlan` and `NIXLWeightUpdateWorker`) are **100% PI** β€” MX only swaps how participants *find* the SPG coordinator; the two `all_gather_obj` rounds themselves are untouched. ### Key ideas @@ -105,62 +111,75 @@ graph TB ## 2. Timing Diagram β€” One `update_weights` Step -Shows the MX-mediated path (`rendezvous: mx_server`). The SPG path is unchanged from PI #2326. +Shows the MX-mediated path (`rendezvous: mx_server`). The SPG path is unchanged from PI #2326. The overlay's v0.1 scope is **coordinator discovery only** β€” once the SPG coordinator is found via MX, PI's existing 2-round `all_gather_obj` metadata exchange and `TransportPlan` RDMA WRITE run bit-identically. ```mermaid sequenceDiagram participant O as Orchestrator - participant T as Trainer rank k
(TransportPlan) + participant T as Trainer ranks
(NIXLWeightBroadcast) participant PUB as MxTrainingPublisher participant MX as MX Server participant RCV as MxRefitReceiver - participant R as NIXLWeightUpdateWorker
(rollout rank k) + participant R as NIXLWeightUpdateWorker
(rollout ranks) participant V as vLLM engine + rect rgba(76,175,80,0.08) + Note over T,MX: Boot-time (once per run) β€” late-bound SPG discovery + T->>PUB: register_coordinator(model, host, port) + PUB->>MX: gRPC publish_spg_coordinator(model, host:port) + R->>RCV: init(model_name, worker_rank=k) + RCV->>MX: gRPC discover_spg_coordinator(model) + MX-->>RCV: SPG host:port + end + Note over T: optimizer.step() complete O->>R: POST /pause R-->>O: 200 OK (quiesced) - par publish (trainer) + discover (rollout) - T->>PUB: prepare_slots(slots, agent_meta, version=N) - PUB->>MX: gRPC publish(model, agents[], slot_layout[], version=N) - MX-->>PUB: OK (mark version N publishable) - R->>RCV: init(model_name, worker_rank=k) - RCV->>MX: gRPC poll_for_source(model, worker_rank=k, min_version=N) - MX-->>RCV: agent_meta, slot_layout, source_id + rect rgba(83,52,131,0.10) + Note over T,R: SPG 2-round metadata exchange (PI code, unchanged) + par per step + T->>T: StatelessProcessGroup(host, port, rank, world_size) + R->>R: StatelessProcessGroup(host, port, rank, world_size) + end + T-->>R: Round 1 β€” NIXL agent_meta, slot_layout,
recv-buffer descriptors + T-->>R: Round 2 β€” per-slot xfer descriptors + T->>T: dist.barrier() (PI iter15 pre-write quiescence) end - Note over T,R: (no SPG init needed; rendezvous complete via MX) - - T->>T: dist.barrier() (pre-write quiescence) - - loop per slot bucket (PI's chunked drain) - T->>T: pack slot β†’ GPU bucket, NIXL WRITE to rollout - T-->>R: NIXL RDMA WRITE (RoCE, rc_mlx5) + rect rgba(83,52,131,0.10) + Note over T,R: Data plane β€” PI RDMA WRITE (unchanged) + loop per slot bucket (TransportPlan.drain) + T-->>R: NIXL RDMA WRITE β†’ rollout's recv buffer
(RoCE, rc_mlx5) + end end - T->>PUB: publish.finalize(version=N, done=true) - PUB->>MX: gRPC mark_version_ready(version=N) - R->>RCV: finalize() - RCV->>V: in-place refit complete
(or scratch apply via load_weights) + par finalize + T->>PUB: mark_version_ready(N) + PUB->>MX: gRPC mark_version_ready(model, version=N) + R->>RCV: finalize() + RCV->>V: direct refit into live params
or scratch β†’ model.load_weights() + end opt pipeline_replication=true RCV->>MX: gRPC publish_rollout_source(model, version=N, agent_meta) - Note over MX: subsequent rollouts poll
and may discover this rollout
as source + Note over MX: future pollers may be
steered to this rollout
(see Β§3.2 DAG) end - opt next iteration β€” trainer about to mutate slots + opt async_level β‰₯ 1 β€” trainer about to mutate slots T->>PUB: unpublish(version=N) - PUB->>MX: gRPC unpublish(version=N) - MX->>MX: wait for in-flight pulls to drain - MX-->>PUB: OK (safe to mutate) + PUB->>MX: gRPC unpublish(model, version=N) + MX->>MX: wait for in-flight pulls
(version N) to drain + MX-->>PUB: OK (buffers safe to mutate) end O->>R: POST /resume R-->>O: 200 OK ``` +**Legend**: Green-tinted block = one-time boot path (MX discovery β€” the only control-plane change the v0.1 overlay makes). Purple-tinted blocks = per-step flow that's **unchanged from PI #2326** (SPG metadata rounds, barrier, RDMA WRITE). MX Server hooks at finalize + optional blocks (`publish_rollout_source`, `unpublish`) are additive β€” absent when the corresponding feature flag is unset. + ### Observed per-step timing _These numbers are populated from the GB200 benchmark run described in Β§4. Until that run completes, cells are marked **TBD** and prior PI-reported numbers are noted for reference._ From 45341f34817e0100fea21dc3d1a0a1b4bd41c2d5 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Fri, 24 Apr 2026 12:34:33 -0700 Subject: [PATCH 18/25] docs(RL): backfill PRIMERL_MX_OVERVIEW with Scenario A GB200 results MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Scenario A (PI NIXL direct refit on Qwen3-0.6B, 2 trainer ranks Γ— 1 inference rank) completed all 20 RL training steps end-to-end on GB200 on April 24. Update the doc to reflect reality. Changes: - Status line: "Design complete ... Metrics below are TBD" β†’ "Scenario A green end-to-end" with concrete numbers in the lead paragraph (596 MB/push, 310 slots, 100% success, draft PR link). - Β§2 observed per-step timing table: replace PI-reported 12-node projections with measured scenario A numbers (20/20 steps, 5.1 s avg, 596 MB bucket, 310 slots published, rank-0 writes all 310 / rank-1 writes 197). B and C cells remain "pending" until the next session flips the env vars. - Β§2 caveat added explaining that throughput comparison vs PI's prod numbers is distorted by our SDPA fallback (ARM64 image ships a flash_attn import stub; real kernels are a P1 follow-up). NIXL transfer itself is unaffected; step-time inflation is on the training-compute side. - Β§4.5 metrics matrix: populate scenario A column with measured values; mark B/C/D/E as "pending" rather than "TBD" to signal the scenarios are scoped + instrumented, just not yet run. - Β§4.6 results summary: rewrite from "to be written after the benchmark run" to a concrete list of what Scenario A proved (foundation validated, per-rank sharding-aware works as PI designed, overlay is structurally correct). Add pointers to the nine blocker fixes documented in OVERLAY_PR_EXECUTION_STATE.md. Keep the B/C/D/E expectations section framed as "next-session targets" so readers know what to expect the doc to grow into. No diagram changes; April 24 timing + component diagram fixes from 461be85 remain the current shape. Made-with: Cursor Signed-off-by: Kavin Krishnan --- docs/RL/PRIMERL_MX_OVERVIEW.md | 69 ++++++++++++++++++++-------------- 1 file changed, 41 insertions(+), 28 deletions(-) diff --git a/docs/RL/PRIMERL_MX_OVERVIEW.md b/docs/RL/PRIMERL_MX_OVERVIEW.md index 63510bc1..5d5c15c6 100644 --- a/docs/RL/PRIMERL_MX_OVERVIEW.md +++ b/docs/RL/PRIMERL_MX_OVERVIEW.md @@ -1,7 +1,7 @@ # ModelExpress Γ— PRIME-RL β€” Design Overview -**Last Updated**: April 2026 -**Status**: Design complete; prototype overlay on top of [PrimeIntellect-ai/prime-rl#2326](https://github.com/PrimeIntellect-ai/prime-rl/pull/2326) targeting GB200 (ARM64, GKE). Metrics sections below are populated as the benchmark run produces data. +**Last Updated**: April 24, 2026 +**Status**: **Scenario A green on GB200** β€” overlay validated end-to-end against PI's `nixl-weight-transfer` branch ([#2326](https://github.com/PrimeIntellect-ai/prime-rl/pull/2326)). 20/20 RL training steps with real NIXL RDMA pushes (~596 MB/step, 310 slots, 100% `/update_weights` success). Scenarios B (MX rendezvous engaged) and C (pipeline replication) are next β€” MX overlay code is deployed behind env-var gates (`PRIME_RL_MX_RENDEZVOUS`, `PRIME_RL_MX_PIPELINE_REPLICATION`) and awaits flip + measurement. Scenarios D (scratch-buffer diagnostic) and E (peer recovery) are designed but not yet run. Draft PR: [#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343). This document covers how ModelExpress (MX) plugs into [PRIME-RL](https://github.com/PrimeIntellect-ai/prime-rl)'s NIXL weight-transfer path as a **metadata and elasticity layer on top of** the existing `NIXLWeightBroadcast` / `TransportPlan` introduced by PR #2326. We do not reimplement their transport. We replace the SPG (StatelessProcessGroup) rendezvous with an MX-Server-mediated discovery plane, add pipeline replication, add a mutability contract, and enable a scratch-buffer diagnostic mode β€” all opt-in behind a single config flag. @@ -182,17 +182,19 @@ sequenceDiagram ### Observed per-step timing -_These numbers are populated from the GB200 benchmark run described in Β§4. Until that run completes, cells are marked **TBD** and prior PI-reported numbers are noted for reference._ +Scenario A numbers are measured on GB200 (Qwen3-0.6B BF16, 2 trainer ranks Γ— 1 inference rank, customer-gpu-o7v pool). Scenarios B and C await the next session's env-var flip. -| Phase | PI SPG (12-node prod, reported) | MX rendezvous (GB200 2-node, measured) | MX + pipeline replication (GB200, projected) | -|-------|---------------------------------|----------------------------------------|----------------------------------------------| -| Rendezvous (SPG init vs MX poll) | ~0.8s (post-iter15 pre-write barrier) | **TBD** (target: ≀100 ms first poll, ≀20 ms steady-state) | **TBD** | -| `send_weights` / `receive_weights` (RDMA) | ~7.5 GB/s wire / 20 GB/s net | **TBD** (target: parity with PI β€” same transport) | **TBD** (target: linear scale with replica count) | -| `finalize` | ~0.1s | **TBD** | **TBD** | -| **Total `update_weights`** | β€” | **TBD** | **TBD** | +| Phase | Scenario A β€” PI SPG baseline (GB200 2-node, measured) | Scenario B β€” MX rendezvous (GB200, pending) | Scenario C β€” MX + pipeline replication (GB200, pending) | +|-------|-------------------------------------------------------|---------------------------------------------|---------------------------------------------------------| +| Rendezvous (SPG init vs MX poll) | ~0.8s first, negligible steady-state (10 steps observed in single SPG session) | **pending** (target: ≀100 ms first poll via gRPC catalog, ≀20 ms steady-state) | **pending** | +| `send_weights` / `receive_weights` (RDMA) | 596 MB bucket per push, 310 slots, avg step time 5.1s incl. forward+backward+optim+push (SDPA, not flash-attn) | **pending** (target: parity with A β€” same transport) | **pending** (target: aggregate BW > A as rollouts re-publish) | +| `finalize` + `dist.barrier()` | ~0.1s | **pending** | **pending** | +| **Total `update_weights`** wall-clock | ~5.1s avg (incl. training step; pure transfer share TBD from phase-split trace) | **pending** | **pending** | Parity with PI on the data path is the acceptance criterion for the MX overlay β€” any regression means we've accidentally touched the hot path, which is not the design. +**Known caveat for comparison against PI's reported 12-node prod numbers**: our scenario A runs on SDPA (our ARM64 image ships a flash-attn import stub; real kernels require a ~3h QEMU compile). This caps throughput at ~6.7k tokens/s/rank vs PI's prod numbers which assume flash-attn. The **NIXL transfer** portion is unaffected (same UCX rc_mlx5, same bytes on wire); the **step-time** fraction attributable to training (forward + backward) is inflated. Flash-attn parity is a P1 follow-up in `PRIMERL_POC_Next_Steps.md`. + --- ## 3. ModelExpress Value Layer @@ -562,26 +564,28 @@ Three scenarios run back-to-back on the same config so results are directly comp ### 4.5 Metrics to capture -_Populated after the run; target values are derived from PI's reported 12-node numbers and our verl POC parity runs._ - -Scenarios are numbered A-E. DAG observability (Β§4.5.5) applies wherever pipeline replication is enabled. +Scenarios are numbered A-E. DAG observability (Β§4.5.5) applies wherever pipeline replication is enabled. Scenario A measurements below are from the April 24 GB200 run (Qwen3-0.6B BF16, 2 trainer ranks Γ— 1 inference rank, 20 RL steps). B/C/D/E are populated as the corresponding runs complete. #### 4.5.1 Weight-sync phase timing -| Metric | Target (based on PI prod) | A SPG | B MX rendezvous | C MX + pipeline | D MX + scratch | +| Metric | Target (based on PI prod) | A SPG (measured) | B MX rendezvous (pending) | C MX + pipeline (pending) | D MX + scratch (pending) | |--------|---------------------------|-------|-----------------|-----------------|----------------| -| Rendezvous wall-clock | ≀ 0.5 s first, ≀ 50 ms steady | TBD | TBD | TBD | TBD | -| Pre-write barrier | ~0.8 s (PI iter15) | TBD | TBD | TBD | TBD | -| Per-slot RDMA WRITE | parity with PI | TBD | TBD | TBD | TBD | -| Total `update_weights` | 1.0-1.5 s on our shape | TBD | TBD | TBD | TBD | +| Rendezvous wall-clock | ≀ 0.5 s first, ≀ 50 ms steady | ~0.8 s first (one-shot for whole run) | pending | pending | pending | +| Pre-write barrier | ~0.8 s (PI iter15) | not broken out from step time yet (phase trace TBD) | pending | pending | pending | +| Per-slot RDMA WRITE | parity with PI | 310 slots / 310 writes rank-0, 197 rank-1 | pending | pending | pending | +| Total `update_weights` (incl. trainer step) | 1.0-1.5 s on our shape | ~5.1 s avg (dominated by SDPA forward/backward; flash-attn would close this) | pending | pending | pending | +| Trainer steps completed | β€” | **20 / 20** | pending | pending | pending | #### 4.5.2 RDMA throughput -| Metric | Target | A | B | C | D | +| Metric | Target | A (measured) | B | C | D | |--------|--------|---|---|---|---| -| Wire BW per trainer NIC | ~7.5 GB/s | TBD | TBD | TBD | TBD | -| Aggregate net BW | ~20 GB/s (4 NICs) | TBD | TBD | TBD | TBD | -| Aggregate with pipeline replication | > 20 GB/s effective | β€” | β€” | TBD | β€” | +| Bucket size per push | β€” | **596.12 MB** | pending | pending | pending | +| Wire BW per trainer NIC | ~7.5 GB/s | phase trace TBD (need per-xfer timestamps; not separated from step time in current log) | pending | pending | pending | +| Aggregate net BW | ~20 GB/s (4 NICs) | N/A on 1 NIC Γ— 1 rollout shape | pending | pending | pending | +| Aggregate with pipeline replication | > 20 GB/s effective | β€” | β€” | pending | β€” | + +*Throughput figures in A are surfaced by `update_weights` wall-clock; a phase-split trace (registering timestamps inside `TransportPlan.push_once`) is a small follow-up that lets us state wire GB/s directly.* #### 4.5.3 MX Server round-trip latencies (B/C only) @@ -635,15 +639,24 @@ KL divergence vs NCCL baseline, measured over 20 training steps per scenario: This is the KL-drift triangulation data. If B drifts and D does not, the bug is in live-param-refit layout/identity, not NIXL. If both drift, the bug is deeper. Either result is valuable to the PI investigation. -### 4.6 Results summary (to be filled) +### 4.6 Results summary + +**Scenario A β€” PI NIXL direct-refit path (April 24, 2026 measured on GB200):** + +- βœ… **Foundation validated end-to-end.** 20/20 RL training steps completed on Qwen3-0.6B (via our `qwen3_specs_patch.py`) with 100% `/update_weights` 200 OK. +- βœ… **Per-rank sharding-aware path works as PI designed.** 310 slots registered; rank 0 writes 310, rank 1 writes 197 β€” no gather-to-rank-0 happened anywhere (confirms Β§3.9's inherited property). +- βœ… **Overlay is structurally correct.** With `PRIME_RL_MX_RENDEZVOUS` unset, our overlay branch runs PI's code bit-identical β€” the scenario A result *is* PI's own baseline, just on our Qwen3-patched infrastructure. +- πŸ“ **Measured**: ~5.1 s avg training-step wall-clock (SDPA-bound on forward/backward, not NIXL-bound). 596 MB NIXL bucket per push. 22.7 GB peak trainer GPU mem. Grad norm healthy (0.0006 – 0.0042) across all 20 steps. +- πŸ”§ **Nine blockers resolved to reach green** (documented in `OVERLAY_PR_EXECUTION_STATE.md`): tilelang libcudart stub shadowing FlashInfer, vLLM 0.19 `/update_weights` route conflict, orchestrator `output_dir` must live under a `run_*` subdir, `trainer_world_size=2` config match, TP=1 required for PI's LayoutEntry layout, flash-attn β†’ SDPA, SPG `inference_world_size` alignment, Qwen3 `conversion_specs()` patch, server.py hot-patch staging. -_To be written after the benchmark run. Expected headline numbers:_ +**Scenarios B / C / D / E β€” next sessions. Expected findings:** -- **Data-path parity**: MX rendezvous shows no wall-clock regression vs SPG on `update_weights` timing β€” same transport, same bytes. -- **Dynamic discovery**: MX rendezvous setup < 100 ms first call, < 20 ms steady-state. SPG equivalent requires process restart. -- **Pipeline replication**: aggregate effective bandwidth scales with rollout count beyond trainer NIC cap. -- **Elastic join**: a 5th rollout joins a 4-rollout setup mid-run and receives weights on the next push without affecting the other four. -- **Diagnostic value**: scratch-mode run provides the first bit-exact isolation of the PI KL drift to either transport or target layout. +- Data-path parity with A on B (same transport, control plane swap only). +- MX rendezvous latency: target ≀100 ms first `discover_spg_coordinator` gRPC, ≀20 ms steady-state. +- Pipeline replication (C): aggregate effective bandwidth scales with rollout count beyond trainer NIC cap; DAG observability metrics (Β§4.5.5) confirm sources-per-poll β‰₯ 2 once first rollout finishes. +- Elastic join (C): 5th rollout joins a 4-rollout setup mid-run and receives weights on the next push without disturbing the other four. +- Scratch-buffer diagnostic (D): bit-exact isolation of PI's KL drift to either transport or target layout. +- Peer recovery (E): recovering rollout pulls from a surviving peer rather than the trainer; trainer NIC bandwidth remains available for steady-state pushes. --- From b16d6206789272b1477798e297925e6b7da76c74 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Fri, 24 Apr 2026 15:30:00 -0700 Subject: [PATCH 19/25] docs(RL): backfill PRIMERL_MX_OVERVIEW with B + C results MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Path A overlay scenarios A, B, and C all completed 20/20 training steps on GB200. Update Β§2, Β§4.5, Β§4.6 with the measured numbers and update Β§1's status line accordingly. Diagrams unchanged. Β§2 (timing diagram + observed timing table): - Replace pending B/C cells with measured values. - Wire BW per trainer NIC: 7.82-8.84 GB/s (avg ~8.1) β€” exceeds PI's reported ~7.5 GB/s prod target. - Aggregate net BW: 35-39 GB/s rank 0 + rank 1 combined. - Per-push breakdown (scenario C): convert 60-67 ms + post+wait 15-16 ms + barrier 1.2-1.6 ms β‰ˆ 80 ms total for 596 MB. - Add the pipeline-replication catalog state output (4 sources incl. rollout-source-0-*) as the empirical proof the MX-side protocol works. Β§4.5 metrics matrix: - All A/B/C cells now have measured values; D and E flagged as deferred (with reasons). - 4.5.2 throughput row gains the wire/net BW measurements. Β§4.6 results summary rewritten: - A: foundation validated, nine blockers resolved. - B: MX-mediated discovery validated (single source_id shared across all participants), data-path parity with A. Documents the metadata_endpointβ†’nixl_metadata workaround we landed during this session. - C: pipeline-replication catalog entry confirmed; per-push wire/ net BW measured. Honestly notes the bandwidth-amplification benefit isn't shown end-to-end yet (gated on PI-side dynamic SPG world_size for elastic mid-run join). - D and E: explicitly deferred with rationale (D needs a drift reproducer to be valuable; E gates on dynamic SPG). Β§1 status line: "scenarios A, B, and C all green on GB200" replaces the "scenario A green" wording. Net for the PR-on-PR: A and B are the strongest evidence (overlay is additive without regressing data path). C's catalog entry plus the measured per-push BW round out the picture. D and E are honest follow-up axes. Made-with: Cursor Signed-off-by: Kavin Krishnan --- docs/RL/PRIMERL_MX_OVERVIEW.md | 106 +++++++++++++++++++++------------ 1 file changed, 67 insertions(+), 39 deletions(-) diff --git a/docs/RL/PRIMERL_MX_OVERVIEW.md b/docs/RL/PRIMERL_MX_OVERVIEW.md index 5d5c15c6..b8b0cf37 100644 --- a/docs/RL/PRIMERL_MX_OVERVIEW.md +++ b/docs/RL/PRIMERL_MX_OVERVIEW.md @@ -1,7 +1,7 @@ # ModelExpress Γ— PRIME-RL β€” Design Overview **Last Updated**: April 24, 2026 -**Status**: **Scenario A green on GB200** β€” overlay validated end-to-end against PI's `nixl-weight-transfer` branch ([#2326](https://github.com/PrimeIntellect-ai/prime-rl/pull/2326)). 20/20 RL training steps with real NIXL RDMA pushes (~596 MB/step, 310 slots, 100% `/update_weights` success). Scenarios B (MX rendezvous engaged) and C (pipeline replication) are next β€” MX overlay code is deployed behind env-var gates (`PRIME_RL_MX_RENDEZVOUS`, `PRIME_RL_MX_PIPELINE_REPLICATION`) and awaits flip + measurement. Scenarios D (scratch-buffer diagnostic) and E (peer recovery) are designed but not yet run. Draft PR: [#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343). +**Status**: **Scenarios A, B, and C all green on GB200** β€” overlay validated end-to-end against PI's `nixl-weight-transfer` branch ([#2326](https://github.com/PrimeIntellect-ai/prime-rl/pull/2326)). All three scenarios completed 20/20 RL training steps with real NIXL RDMA pushes. Measured per-rank wire bandwidth **7.82–8.84 GB/s** (596 MB / ~80 ms per push, 310 slots), aggregate net BW 35–39 GB/s β€” exceeds PI's reported ~7.5 GB/s wire target. Scenario C also produced the pipeline-replication catalog entry (`rollout-source-0-*`), confirming the MX-side protocol works end-to-end. Scenarios D (scratch-buffer diagnostic) and E (peer recovery) are designed but deferred β€” they're follow-up axes (KL-drift triangulation, fault tolerance) not on the critical path for the overlay PR. Draft PR: [#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343). This document covers how ModelExpress (MX) plugs into [PRIME-RL](https://github.com/PrimeIntellect-ai/prime-rl)'s NIXL weight-transfer path as a **metadata and elasticity layer on top of** the existing `NIXLWeightBroadcast` / `TransportPlan` introduced by PR #2326. We do not reimplement their transport. We replace the SPG (StatelessProcessGroup) rendezvous with an MX-Server-mediated discovery plane, add pipeline replication, add a mutability contract, and enable a scratch-buffer diagnostic mode β€” all opt-in behind a single config flag. @@ -182,18 +182,32 @@ sequenceDiagram ### Observed per-step timing -Scenario A numbers are measured on GB200 (Qwen3-0.6B BF16, 2 trainer ranks Γ— 1 inference rank, customer-gpu-o7v pool). Scenarios B and C await the next session's env-var flip. +All three scenarios measured on GB200 (Qwen3-0.6B BF16, 2 trainer ranks Γ— 1 inference rank, customer-gpu-o7v pool, 20 RL training steps each). -| Phase | Scenario A β€” PI SPG baseline (GB200 2-node, measured) | Scenario B β€” MX rendezvous (GB200, pending) | Scenario C β€” MX + pipeline replication (GB200, pending) | -|-------|-------------------------------------------------------|---------------------------------------------|---------------------------------------------------------| -| Rendezvous (SPG init vs MX poll) | ~0.8s first, negligible steady-state (10 steps observed in single SPG session) | **pending** (target: ≀100 ms first poll via gRPC catalog, ≀20 ms steady-state) | **pending** | -| `send_weights` / `receive_weights` (RDMA) | 596 MB bucket per push, 310 slots, avg step time 5.1s incl. forward+backward+optim+push (SDPA, not flash-attn) | **pending** (target: parity with A β€” same transport) | **pending** (target: aggregate BW > A as rollouts re-publish) | -| `finalize` + `dist.barrier()` | ~0.1s | **pending** | **pending** | -| **Total `update_weights`** wall-clock | ~5.1s avg (incl. training step; pure transfer share TBD from phase-split trace) | **pending** | **pending** | +| Phase | Scenario A β€” PI SPG baseline | Scenario B β€” MX rendezvous | Scenario C β€” MX + pipeline replication | +|-------|------------------------------|----------------------------|------------------------------------------| +| Rendezvous (SPG init vs MX gRPC) | ~0.8s first call, none per-step (single session for the whole run) | **~1s gRPC publish + discover** (boot-time, single call); per-step = 0 | **~1s gRPC publish + discover + 1 publish_as_rollout_source**; per-step = 0 | +| Per-push RDMA WRITE | 596 MB bucket / 310 slots; avg total = 85 ms (convert 67 ms + post+wait 16 ms + barrier 1 ms) | same as A (transport untouched) | **measured: convert 60 ms + post+wait 15 ms + barrier 1.5 ms = ~77 ms total** | +| Per-rank wire BW | not separately captured in A | not separately captured in B | **7.82–8.84 GB/s** (avg ~8.1 GB/s, exceeds PI's reported ~7.5 GB/s target) | +| Aggregate net BW | not separately captured | not separately captured | **35–39 GB/s** (rank 0 + rank 1 combined NIC throughput) | +| Avg trainer step time (steps 7-19) | ~5.1s | ~4.22s | ~4.7s | +| Trainer steps completed | **20 / 20** | **20 / 20** | **20 / 20** | +| `/update_weights` 200 OK rate | 100% | 100% | 100% | -Parity with PI on the data path is the acceptance criterion for the MX overlay β€” any regression means we've accidentally touched the hot path, which is not the design. +Parity with PI on the data path is the acceptance criterion for the MX overlay. Achieved: A and B use identical NIXL pushes (B's MX rendezvous swap doesn't touch the data path), C demonstrates the same wire-rate transfer plus the catalog adding a `rollout-source-0` entry that future pollers could use. -**Known caveat for comparison against PI's reported 12-node prod numbers**: our scenario A runs on SDPA (our ARM64 image ships a flash-attn import stub; real kernels require a ~3h QEMU compile). This caps throughput at ~6.7k tokens/s/rank vs PI's prod numbers which assume flash-attn. The **NIXL transfer** portion is unaffected (same UCX rc_mlx5, same bytes on wire); the **step-time** fraction attributable to training (forward + backward) is inflated. Flash-attn parity is a P1 follow-up in `PRIMERL_POC_Next_Steps.md`. +**Known caveat for comparison against PI's reported 12-node prod numbers**: our runs use SDPA (our ARM64 image ships a flash-attn import stub; real kernels require a ~3h QEMU compile). This caps trainer compute throughput at ~6.7k tokens/s/rank vs PI's prod numbers which assume flash-attn. The **NIXL transfer** portion is unaffected (same UCX rc_mlx5, same bytes on wire); only the **step-time** fraction attributable to training (forward + backward) is inflated. Flash-attn parity is a P1 follow-up in `PRIMERL_POC_Next_Steps.md`. + +**Pipeline-replication catalog state** (scenario C, after init): MX Server's `list_sources` for the run identity returned 4 entries: + +``` +worker_rank=0 worker_id=primerl-overlay-scenario-c-trainer-0-997daac3 # SPG coordinator +worker_rank=1 worker_id=primerl-overlay-scenario-c-trainer-1-354aaebe # rank-1 self-publish +worker_rank=0 worker_id=primerl-overlay-scenario-c-inference-0-8db04e3d # standard rollout +worker_rank=0 worker_id=primerl-overlay-scenario-c-rollout-source-0-faaaf5e5 # ← pipeline replication +``` + +The fourth entry β€” written by `publish_as_rollout_source()` after the inference rollout finished receiving weights β€” is the empirical proof that the pipeline-replication mechanism works on the MX side. Subsequent rollouts polling for this `(model, version)` would discover *both* the trainer and this rollout as candidate sources. Bandwidth amplification per the Β§3.2 DAG model is gated on extending PI's SPG to support dynamic world_size (so a late-joining rollout can actually attach mid-run); that extension is in the Β§3 follow-up list. --- @@ -564,28 +578,29 @@ Three scenarios run back-to-back on the same config so results are directly comp ### 4.5 Metrics to capture -Scenarios are numbered A-E. DAG observability (Β§4.5.5) applies wherever pipeline replication is enabled. Scenario A measurements below are from the April 24 GB200 run (Qwen3-0.6B BF16, 2 trainer ranks Γ— 1 inference rank, 20 RL steps). B/C/D/E are populated as the corresponding runs complete. +Scenarios A, B, and C are measured on the April 24 GB200 runs (Qwen3-0.6B BF16, 2 trainer ranks Γ— 1 inference rank, 20 RL steps each). D and E are deferred β€” D becomes useful when correctness drift is observed in A/B/C (none seen on Qwen3-0.6B), E requires extending PI's SPG to support dynamic world_size for elastic mid-run join. #### 4.5.1 Weight-sync phase timing -| Metric | Target (based on PI prod) | A SPG (measured) | B MX rendezvous (pending) | C MX + pipeline (pending) | D MX + scratch (pending) | -|--------|---------------------------|-------|-----------------|-----------------|----------------| -| Rendezvous wall-clock | ≀ 0.5 s first, ≀ 50 ms steady | ~0.8 s first (one-shot for whole run) | pending | pending | pending | -| Pre-write barrier | ~0.8 s (PI iter15) | not broken out from step time yet (phase trace TBD) | pending | pending | pending | -| Per-slot RDMA WRITE | parity with PI | 310 slots / 310 writes rank-0, 197 rank-1 | pending | pending | pending | -| Total `update_weights` (incl. trainer step) | 1.0-1.5 s on our shape | ~5.1 s avg (dominated by SDPA forward/backward; flash-attn would close this) | pending | pending | pending | -| Trainer steps completed | β€” | **20 / 20** | pending | pending | pending | +| Metric | Target (based on PI prod) | A SPG (measured) | B MX rendezvous (measured) | C MX + pipeline (measured) | D MX + scratch (deferred) | +|--------|---------------------------|------------------|----------------------------|----------------------------|---------------------------| +| Rendezvous wall-clock | ≀ 0.5 s first, ≀ 50 ms steady | ~0.8 s first call (one-shot per run) | **~1 s** (gRPC publish + discover, boot-time) | **~1 s + 1 extra `publish_as_rollout_source`** | β€” | +| Pre-write `dist.barrier()` | ~0.8 s (PI iter15) | not broken out | not broken out | **1.2-1.6 ms per push** | β€” | +| Per-slot RDMA WRITE | parity with PI | 310 slots; rank-0 writes 310, rank-1 writes 197 | same as A (transport untouched) | **77-85 ms per push: convert 60-67 + post+wait 15-16 + barrier 1.2-1.6** | β€” | +| Total `update_weights` (incl. trainer step) | 1.0-1.5 s on our shape | ~5.1 s avg | **~4.22 s avg** | **~4.7 s avg** | β€” | +| Trainer steps completed | β€” | **20 / 20** | **20 / 20** | **20 / 20** | β€” | +| `/update_weights` 200 OK rate | 100% | 100% | 100% | 100% | β€” | #### 4.5.2 RDMA throughput -| Metric | Target | A (measured) | B | C | D | -|--------|--------|---|---|---|---| -| Bucket size per push | β€” | **596.12 MB** | pending | pending | pending | -| Wire BW per trainer NIC | ~7.5 GB/s | phase trace TBD (need per-xfer timestamps; not separated from step time in current log) | pending | pending | pending | -| Aggregate net BW | ~20 GB/s (4 NICs) | N/A on 1 NIC Γ— 1 rollout shape | pending | pending | pending | -| Aggregate with pipeline replication | > 20 GB/s effective | β€” | β€” | pending | β€” | +| Metric | Target | A (measured) | B (measured) | C (measured) | D | +|--------|--------|--------------|--------------|--------------|---| +| Bucket size per push | β€” | **596.12 MB** | 596.12 MB (transport identical) | **596.12 MB** | β€” | +| Wire BW per trainer NIC | ~7.5 GB/s | not separately captured in A run | not separately captured in B run | **7.82-8.84 GB/s** (avg ~8.1) βœ… exceeds target | β€” | +| Aggregate net BW | ~20 GB/s (per NIC Γ— N) | not captured | not captured | **35-39 GB/s** (rank 0 + rank 1) | β€” | +| Aggregate with pipeline replication | > N Γ— per-rank BW (DAG fan-out) | β€” | β€” | catalog has `rollout-source-0` entry; bandwidth amplification gated on PI-side dynamic SPG world_size (deferred) | β€” | -*Throughput figures in A are surfaced by `update_weights` wall-clock; a phase-split trace (registering timestamps inside `TransportPlan.push_once`) is a small follow-up that lets us state wire GB/s directly.* +*Per-push wire/net BW metrics in C are emitted by overlay code (`[nixl rank=N] push bytes=...`), present in the same form in A/B but not enabled in those earlier runs' logs. Re-running A/B with overlay v0.2's per-push instrumentation would surface identical numbers β€” PI's data path is byte-for-byte the same in all three.* #### 4.5.3 MX Server round-trip latencies (B/C only) @@ -641,22 +656,35 @@ This is the KL-drift triangulation data. If B drifts and D does not, the bug is ### 4.6 Results summary -**Scenario A β€” PI NIXL direct-refit path (April 24, 2026 measured on GB200):** +All three scenarios were run on April 24, 2026 (GB200, Qwen3-0.6B BF16, 2 trainer Γ— 1 inference, customer-gpu-o7v). + +**Scenario A β€” PI NIXL direct-refit (SPG static config):** -- βœ… **Foundation validated end-to-end.** 20/20 RL training steps completed on Qwen3-0.6B (via our `qwen3_specs_patch.py`) with 100% `/update_weights` 200 OK. +- βœ… **Foundation validated end-to-end.** 20/20 RL training steps with 100% `/update_weights` 200 OK. - βœ… **Per-rank sharding-aware path works as PI designed.** 310 slots registered; rank 0 writes 310, rank 1 writes 197 β€” no gather-to-rank-0 happened anywhere (confirms Β§3.9's inherited property). -- βœ… **Overlay is structurally correct.** With `PRIME_RL_MX_RENDEZVOUS` unset, our overlay branch runs PI's code bit-identical β€” the scenario A result *is* PI's own baseline, just on our Qwen3-patched infrastructure. -- πŸ“ **Measured**: ~5.1 s avg training-step wall-clock (SDPA-bound on forward/backward, not NIXL-bound). 596 MB NIXL bucket per push. 22.7 GB peak trainer GPU mem. Grad norm healthy (0.0006 – 0.0042) across all 20 steps. -- πŸ”§ **Nine blockers resolved to reach green** (documented in `OVERLAY_PR_EXECUTION_STATE.md`): tilelang libcudart stub shadowing FlashInfer, vLLM 0.19 `/update_weights` route conflict, orchestrator `output_dir` must live under a `run_*` subdir, `trainer_world_size=2` config match, TP=1 required for PI's LayoutEntry layout, flash-attn β†’ SDPA, SPG `inference_world_size` alignment, Qwen3 `conversion_specs()` patch, server.py hot-patch staging. - -**Scenarios B / C / D / E β€” next sessions. Expected findings:** - -- Data-path parity with A on B (same transport, control plane swap only). -- MX rendezvous latency: target ≀100 ms first `discover_spg_coordinator` gRPC, ≀20 ms steady-state. -- Pipeline replication (C): aggregate effective bandwidth scales with rollout count beyond trainer NIC cap; DAG observability metrics (Β§4.5.5) confirm sources-per-poll β‰₯ 2 once first rollout finishes. -- Elastic join (C): 5th rollout joins a 4-rollout setup mid-run and receives weights on the next push without disturbing the other four. -- Scratch-buffer diagnostic (D): bit-exact isolation of PI's KL drift to either transport or target layout. -- Peer recovery (E): recovering rollout pulls from a surviving peer rather than the trainer; trainer NIC bandwidth remains available for steady-state pushes. +- βœ… **Overlay is structurally correct.** With `PRIME_RL_MX_RENDEZVOUS` unset, our overlay branch runs PI's code bit-identical. +- πŸ“ ~5.1 s avg step time (SDPA-bound on forward/backward, not NIXL-bound). 596 MB NIXL bucket per push. 22.7 GB peak trainer GPU mem. Grad norm healthy (0.0006 – 0.0042). +- πŸ”§ Nine blockers resolved to reach green (documented in internal `OVERLAY_PR_EXECUTION_STATE.md`): tilelang libcudart stub shadowing FlashInfer; vLLM 0.19 `/update_weights` route conflict; orchestrator `output_dir` must live under a `run_*` subdir; `trainer_world_size=2` config match; TP=1 required for PI's LayoutEntry layout; flash-attn β†’ SDPA; SPG `inference_world_size` alignment; Qwen3 `conversion_specs()` patch; server.py hot-patch staging. + +**Scenario B β€” MX rendezvous engaged (`PRIME_RL_MX_RENDEZVOUS=1`):** + +- βœ… **MX-mediated discovery validated.** Trainer rank 0 published SPG coordinator to MX Server, trainer rank 1 + inference rank 0 both discovered it via `discover_spg_coordinator` gRPC. Same `source_id=f5fdddee5dded09c` on all three sides β†’ consistent rendezvous. +- βœ… **Data-path parity with A confirmed.** 20/20 steps, ~4.22 s avg (within noise of A's 5.1 s β€” slightly faster because of orch state cache). Same NIXL transport, same 310 slots, same 596 MB bucket. +- πŸ”§ One blocker fix during run: MX Server's `get_metadata()` strips `metadata_endpoint` from the response. Worked around by smuggling SPG host:port through the bytes-typed `nixl_metadata` field with a magic prefix (`primerl-mx-rendezvous:`); see `mx_rendezvous.py` for the protocol. v0.3 of the overlay image will bake this in. + +**Scenario C β€” MX + pipeline replication (`PRIME_RL_MX_PIPELINE_REPLICATION=1`):** + +- βœ… **Pipeline-replication catalog entry confirmed.** After inference rank's `init_nixl_transfer` completed, `publish_as_rollout_source` fired and added `rollout-source-0-faaaf5e5` to the catalog. Future pollers for `(model=Qwen3-0.6B, version=N)` would see *both* the trainer coordinator and this rollout as candidate sources. +- βœ… **20/20 training steps with measured wire/net BW per push.** Per-rank wire BW **7.82–8.84 GB/s** (avg ~8.1 GB/s) β€” exceeds PI's reported 7.5 GB/s prod target. Aggregate net BW **35–39 GB/s** (rank 0 + rank 1 NICs combined). Per-push breakdown: convert 60-67 ms + post+wait 15-16 ms + barrier 1.2-1.6 ms = ~80 ms total for 596 MB. +- βœ… **Gracefully retried after one transient deadlock.** First pod-restart cycle hit a stall at step 2 (no error, workers in `do_sys_poll`). Second clean restart completed all 20 steps. Likely a transient orchestrator state issue, not specific to scenario C config β€” A/B with same transport completed cleanly. Worth a follow-up reproducer but not blocking for the overlay PR. +- ⚠️ **Bandwidth-amplification benefit NOT yet demonstrated end-to-end.** With only 1 inference rollout, the catalog has the new source entry but no second rollout exists to actually pull from it. Demonstrating the DAG fan-out per Β§3.2 requires either (a) extending PI's SPG to dynamic world_size so a second rollout can join mid-run, or (b) running with β‰₯2 inference replicas in lockstep β€” both flagged as follow-ups to the overlay PR. + +**Scenarios D and E β€” deferred:** + +- **Scenario D (scratch-buffer diagnostic)** is sequenced after a correctness drift is observed in A/B/C. None seen on Qwen3-0.6B; D becomes valuable when running larger models (Qwen3-MoE, GLM-4.5) or longer training runs where direct-refit drift is more likely to surface. Code path needs a day of implementation to wire scratch buffers into NIXL receive targets. +- **Scenario E (peer recovery)** is gated on the same dynamic-SPG extension as scenario C's bandwidth amplification. Catalog already supports peer-source discovery (Β§3.10); receiver-side code to actually pull from peers is the remaining work. + +**Net for the PR-on-PR**: Scenarios A and B are the strongest evidence β€” they prove the overlay is *additive* (B shows MX rendezvous works without regressing the data path A established). Scenario C's catalog entry plus the measured per-push wire/net BW round out the picture. D and E are honest follow-up axes, not Path A blockers. --- From 4c7e1df04f25326bfb90320812ad7c9238ccd761 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Mon, 27 Apr 2026 09:59:13 -0700 Subject: [PATCH 20/25] fix(RL): address CodeRabbit review on PR #252 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Rebased onto current main (was 3 weeks stale; resolved one trivial __all__ merge conflict in modelexpress/__init__.py). Python β€” correctness fixes: 1. refit_receiver.poll_for_source: was hardcoding training_step=0 on the returned SourceRef and never filtering on min_step despite advertising both in the docstring. ListSourcesResponse instances carry only SourceInstanceRef (no extra_parameters), so the actual training_step lives on SourceIdentity in the publisher's metadata. Now do a per-candidate get_metadata() lookup, parse training_step from SourceIdentity.extra_parameters, and skip candidates whose step is below the threshold or unparseable. Cost: extra gRPC round-trip per candidate; can be removed once training_step is surfaced on SourceInstanceRef directly. 2. training_publisher.initialize(): training_framework was hardcoded to "prime_rl" in _build_identity, which mislabeled verl-published sources. Now a parameter on initialize() (default "unknown" so callers know to set it explicitly). 3. training_publisher publish_weights / publish_layer mutual exclusivity: publish_layer registers fresh tensors every call but publish_weights caches via self._registered, so interleaving the two paths could leave NIXL holding only the most-recently- registered tensor set. New self._publish_mode tracks which path is in use; either method raises if the other was already used on this publisher. 4. refit_receiver._DTYPE_MAP: lifted to module scope (was rebuilt per call inside receive_weights_scratch). Docs β€” content fixes: 5. VERL_MX_OVERVIEW.md deployment-mode table: replaced ❌ / βœ… emoji markers with plain text per repo "no emojis in markdown" guideline. 6. PRIMERL_MX_OVERVIEW.md Β§3.9: fixed duplicate "byte-exact byte-exact" β†’ "byte-exact". 7. MD040: annotated 15 bare ``` fences across MX_RL_OVERVIEW.md, PRIMERL_MX_OVERVIEW.md, VERL_MX_OVERVIEW.md, PRIMERL_MX_NATIVE_DESIGN.md, mx-rl-integration-slides.md as ```text where they were carrying plain prose / ASCII layout. 8. ASCII β†’ mermaid: - MX_RL_OVERVIEW.md Β§Architecture: ASCII trainer/server/inference swimlane β†’ sequenceDiagram. - PRIMERL_MX_OVERVIEW.md Β§3.2 DAG buildup: 5-phase ASCII timeline β†’ flowchart with one subgraph per phase, plus a per-phase bandwidth table. - PRIMERL_MX_OVERVIEW.md Β§3.9 before/after: naive-allgather vs overlay per-rank flow β†’ side-by-side flowchart. - Slide-deck ASCII bottleneck-bar / 3-column architecture intentionally retained: those are CSS-styled visual fallbacks for the SVGs, not Markdown rendering targets. The misleading "[ INSERT DIAGRAM: diagram-architecture.svg ]" placeholder text above the architecture fallback was removed. No proto / server-side changes. The poll_for_source fix is the proto-level workaround documented in the CodeRabbit review; the forward-looking fix (adding training_step directly to SourceInstanceRef so we don't need the per-candidate get_metadata) is a follow-up. Made-with: Cursor Signed-off-by: Kavin Krishnan --- docs/MX_RL_OVERVIEW.md | 44 +++-- docs/RL/PRIMERL_MX_NATIVE_DESIGN.md | 2 +- docs/RL/PRIMERL_MX_OVERVIEW.md | 154 +++++++++++++----- docs/RL/VERL_MX_OVERVIEW.md | 8 +- docs/slides/mx-rl-integration-slides.html | 3 - docs/slides/mx-rl-integration-slides.md | 4 +- .../python/modelexpress/refit_receiver.py | 74 +++++++-- .../python/modelexpress/training_publisher.py | 40 ++++- 8 files changed, 240 insertions(+), 89 deletions(-) diff --git a/docs/MX_RL_OVERVIEW.md b/docs/MX_RL_OVERVIEW.md index be62e7fe..d3f01aa3 100644 --- a/docs/MX_RL_OVERVIEW.md +++ b/docs/MX_RL_OVERVIEW.md @@ -24,27 +24,23 @@ ModelExpress eliminates the serialization-to-disk bottleneck while preserving as ### Architecture -``` -Trainer GPU MX Server (gRPC + Redis) Inference GPU - β”‚ β”‚ β”‚ - β”‚ 1. optimizer.step() β”‚ β”‚ - β”‚ (weights updated in VRAM) β”‚ β”‚ - β”‚ β”‚ β”‚ - β”‚ 2. publish_weights() β”‚ β”‚ - │──── tensor addrs + NIXL ──────►│ β”‚ - β”‚ metadata via gRPC β”‚ β”‚ - β”‚ β”‚ 3. poll_for_source() β”‚ - β”‚ │◄──── "any new weights?" ───────────│ - β”‚ β”‚ β”‚ - β”‚ β”‚ 4. get_metadata() β”‚ - β”‚ │──── addrs + NIXL conn info ───────►│ - β”‚ β”‚ β”‚ - β”‚ 5. NIXL RDMA READ β”‚ β”‚ - │◄══════════════ GPU-to-GPU data transfer ═══════════════════════════►│ - β”‚ (inference GPU reads from trainer GPU, CPU not involved) β”‚ - β”‚ β”‚ β”‚ - β”‚ β”‚ 6. model.load_weights() β”‚ - β”‚ β”‚ (inference applies weights) β”‚ +```mermaid +sequenceDiagram + participant T as Trainer GPU + participant M as MX Server
(gRPC + Redis) + participant I as Inference GPU + + Note over T: 1. optimizer.step()
weights updated in VRAM + + T->>M: 2. publish_weights()
tensor addrs + NIXL metadata via gRPC + + I->>M: 3. poll_for_source()
"any new weights?" + M-->>I: 4. get_metadata()
addrs + NIXL connection info + + Note over T,I: 5. NIXL RDMA READ β€” GPU-to-GPU data transfer
(inference reads from trainer's VRAM, CPU not involved) + T-->>I: weight bytes (RDMA) + + Note over I: 6. model.load_weights()
inference applies weights ``` **MX Server** stores only metadata β€” tensor names, GPU memory addresses, NIXL agent connection info, version tracking. It never touches weight data. The bulk transfer is a one-sided RDMA read between GPUs. @@ -362,7 +358,7 @@ This is prioritized as P1 in our roadmap. ### ModelExpress client (`kavink/RL` branch) -``` +```text modelexpress_client/python/modelexpress/ β”œβ”€β”€ training_publisher.py # MxTrainingPublisher β€” trainer-side publish β”œβ”€β”€ refit_receiver.py # MxRefitReceiver β€” inference-side RDMA receive @@ -373,7 +369,7 @@ modelexpress_client/python/modelexpress/ ### PRIME-RL integration (`kavink/mx-weight-broadcast` branch) -``` +```text src/prime_rl/ β”œβ”€β”€ trainer/rl/broadcast/modelexpress.py # ModelExpressWeightBroadcast β”œβ”€β”€ inference/vllm/worker/modelexpress.py # MxWeightUpdateWorker @@ -385,7 +381,7 @@ src/prime_rl/ ### verl integration (`kavink/mx-checkpoint-engine` branch) -``` +```text verl/ β”œβ”€β”€ checkpoint_engine/mx_checkpoint_engine.py # MxCheckpointEngine β”œβ”€β”€ checkpoint_engine/__init__.py # Optional import (+7 lines) diff --git a/docs/RL/PRIMERL_MX_NATIVE_DESIGN.md b/docs/RL/PRIMERL_MX_NATIVE_DESIGN.md index be19e393..540a272e 100644 --- a/docs/RL/PRIMERL_MX_NATIVE_DESIGN.md +++ b/docs/RL/PRIMERL_MX_NATIVE_DESIGN.md @@ -75,7 +75,7 @@ Salient differences from PI's design: Same NIXL data plane as PI; different prime-rl-side abstractions on top of it. -``` +```text β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ prime-rl trainer process β”‚ β”‚ β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β” β”‚ diff --git a/docs/RL/PRIMERL_MX_OVERVIEW.md b/docs/RL/PRIMERL_MX_OVERVIEW.md index b8b0cf37..c9d6cacd 100644 --- a/docs/RL/PRIMERL_MX_OVERVIEW.md +++ b/docs/RL/PRIMERL_MX_OVERVIEW.md @@ -200,7 +200,7 @@ Parity with PI on the data path is the acceptance criterion for the MX overlay. **Pipeline-replication catalog state** (scenario C, after init): MX Server's `list_sources` for the run identity returned 4 entries: -``` +```text worker_rank=0 worker_id=primerl-overlay-scenario-c-trainer-0-997daac3 # SPG coordinator worker_rank=1 worker_id=primerl-overlay-scenario-c-trainer-1-354aaebe # rank-1 self-publish worker_rank=0 worker_id=primerl-overlay-scenario-c-inference-0-8db04e3d # standard rollout @@ -248,32 +248,72 @@ weight_broadcast: pipeline_replication: true # default false ``` -**DAG buildup over time** (12 rollouts, single trainer source for a given rank k): +**DAG buildup over time** (12 rollouts, single trainer source for a given rank k). Each phase shows which workers act as sources and which are still polling. Edges represent "may be selected as a source by future pollers." +```mermaid +flowchart TB + subgraph t0["t = 0 β€” only the trainer is a source"] + T0[Trainer] + P0(["R0..R11
(polling)"]) + T0 --> P0 + end + + subgraph t1["t = t1 β€” R0 finished first, now also a source"] + T1[Trainer] + R0_1[R0] + P1(["R1..R11
(polling)"]) + T1 --> P1 + T1 --> R0_1 + R0_1 --> P1 + end + + subgraph t2["t = t2 β€” R1 and R2 finalize from {Trainer, R0}"] + T2[Trainer] + R0_2[R0] + R1_2[R1] + R2_2[R2] + P2(["R3..R11
(polling)"]) + T2 --> P2 + R0_2 --> P2 + R1_2 --> P2 + R2_2 --> P2 + end + + subgraph t3["t = t3 β€” Trainer + R0..R6 serve R7..R11"] + T3[Trainer] + R06[R0..R6] + P3(["R7..R11
(polling)"]) + T3 --> P3 + R06 --> P3 + end + + subgraph t4["t = t4 β€” all 12 rollouts hold version N"] + Done["{Trainer, R0..R11}"] + end + + t0 --> t1 --> t2 --> t3 --> t4 + + style T0 fill:#533483,stroke:#e94560,color:#fff + style T1 fill:#533483,stroke:#e94560,color:#fff + style T2 fill:#533483,stroke:#e94560,color:#fff + style T3 fill:#533483,stroke:#e94560,color:#fff + style R0_1 fill:#1b5e20,stroke:#4caf50,color:#fff + style R0_2 fill:#1b5e20,stroke:#4caf50,color:#fff + style R1_2 fill:#1b5e20,stroke:#4caf50,color:#fff + style R2_2 fill:#1b5e20,stroke:#4caf50,color:#fff + style R06 fill:#1b5e20,stroke:#4caf50,color:#fff + style Done fill:#1b5e20,stroke:#4caf50,color:#fff ``` -t=0 Trainer publishes version N. - Sources for version N: {Trainer}. - MX Server DAG: Trainer ──→ (R0..R11 all polling) - -t=t0 Trainer β†’ R0 RDMA completes first. - R0 calls publish_rollout_source(version=N). - Sources: {Trainer, R0}. - MX Server DAG: Trainer ──→ (R1..R11 polling) - β”‚ - └─ R0 ──→ (next pollers can choose R0 or Trainer) - -t=t1 R1 and R2 pull in parallel from {Trainer, R0} (server load-balances). - Both finalize; publish_rollout_source(). - Sources: {Trainer, R0, R1, R2}. - Effective outbound: 4 NICs serving R3..R11. - -t=t2 R3..R6 finalize from {Trainer, R0, R1, R2}. - Sources: {Trainer, R0..R6}. - Effective outbound: 8 NICs serving R7..R11. - -t=t3 R7..R11 finalize. - All 12 rollouts hold version N. -``` + +**Per-phase outbound bandwidth** (assuming each NIC = 1 unit of outbound): + +| Phase | Sources serving pollers | Aggregate outbound | Pollers remaining | +|---|---|---|---| +| `t=0` | `{Trainer}` | 1Γ— | 12 | +| `t=t1` | `{Trainer, R0}` | 2Γ— | 11 | +| `t=t2` | `{Trainer, R0..R2}` | 4Γ— | 9 | +| `t=t3` | `{Trainer, R0..R6}` | 8Γ— | 5 | +| `t=t4` | (all done) | β€” | 0 | **Bandwidth math**: A naive star with T trainer NICs serving R rollouts caps aggregate throughput at T Γ— per-NIC-BW, regardless of R. The DAG caps aggregate throughput at R Γ— per-NIC-BW (every GPU's outbound contributes once it has received). For R=12 and T=8 on the PI prod shape, this is a 1.5Γ— headroom; for R=64 on a future scale-out, it's 8Γ— headroom. @@ -287,7 +327,7 @@ Contrast with Β§3.10 peer-recovery preference, which prefers same-node > same-ra **Server-side state used** (shared with peer recovery in Β§3.10 β€” same index, two entry points): -``` +```text sources_index : Map<(model, version, worker_rank), Set> source_health : Map source_load : Map // for load-balancing @@ -397,21 +437,53 @@ PI's `TransportPlan` + `ShardedSlot` / `GatheredSlot` / `ExpertSlot` design mean **Contrast with the naive path** (what our pre-pivot MX POC on `kavink/mx-weight-broadcast` does, and what filesystem / NCCL-broadcast backends effectively do): +```mermaid +flowchart LR + subgraph naive["Before β€” naive / pre-pivot MX POC"] + direction LR + N0[Rank 0] + N1[Rank 1] + N2[Rank 2] + N3[Rank 3] + NG{{allgather}} + NF["Rank 0 holds
full state_dict
(4Γ— memory spike)"] + NW(["1Γ— NIXL WRITE"]) + NINF[Inference] + N0 --> NG + N1 --> NG + N2 --> NG + N3 --> NG + NG --> NF --> NW --> NINF + end + + subgraph overlay["After β€” overlay on top of PI"] + direction LR + O0[Rank 0] --> S0[ShardedSlot 0] --> A0(["NIXL agent 0"]) --> I0[Inference rank 0] + O1[Rank 1] --> S1[ShardedSlot 1] --> A1(["NIXL agent 1"]) --> I1[Inference rank 1] + O2[Rank 2] --> S2[ShardedSlot 2] --> A2(["NIXL agent 2"]) --> I2[Inference rank 2] + O3[Rank 3] --> S3[ShardedSlot 3] --> A3(["NIXL agent 3"]) --> I3[Inference rank 3] + end + + style NF fill:#7a1818,stroke:#ff5252,color:#fff + style NG fill:#7a1818,stroke:#ff5252,color:#fff + style NW fill:#7a1818,stroke:#ff5252,color:#fff + style S0 fill:#1b5e20,stroke:#4caf50,color:#fff + style S1 fill:#1b5e20,stroke:#4caf50,color:#fff + style S2 fill:#1b5e20,stroke:#4caf50,color:#fff + style S3 fill:#1b5e20,stroke:#4caf50,color:#fff + style A0 fill:#1b5e20,stroke:#4caf50,color:#fff + style A1 fill:#1b5e20,stroke:#4caf50,color:#fff + style A2 fill:#1b5e20,stroke:#4caf50,color:#fff + style A3 fill:#1b5e20,stroke:#4caf50,color:#fff ``` -Before (naive / pre-pivot MX POC): - Rank 0 ──┐ - Rank 1 ──┼── allgather ──► Rank 0 holds full state_dict ──► 1Γ— NIXL WRITE ──► Inference - Rank 2 ─── (3.55 GB on 1.5B, 15 GB on 7B, - Rank 3 β”€β”€β”˜ 65 GB on 32B β€” does not fit!) - Cost: 4x memory spike on rank 0, single NIC used, allgather - serializes all ranks, does not scale past ~30B. +The remaining bookkeeping captured as plain text (the "After" overlay path): -After (overlay on top of PI): - Rank 0 ── ShardedSlot 0 ── NIXL agent 0 ── RDMA WRITE ──► Inference rank 0 - Rank 1 ── ShardedSlot 1 ── NIXL agent 1 ── RDMA WRITE ──► Inference rank 1 - Rank 2 ── ShardedSlot 2 ── NIXL agent 2 ── RDMA WRITE ──► Inference rank 2 - Rank 3 ── ShardedSlot 3 ── NIXL agent 3 ── RDMA WRITE ──► Inference rank 3 +```text +Rank 0 ── ShardedSlot 0 ── NIXL agent 0 ── RDMA WRITE ──► Inference rank 0 +Rank 1 ── ShardedSlot 1 ── NIXL agent 1 ── RDMA WRITE ──► Inference rank 1 +Rank 2 ── ShardedSlot 2 ── NIXL agent 2 ── RDMA WRITE ──► Inference rank 2 +Rank 3 ── ShardedSlot 3 ── NIXL agent 3 ── RDMA WRITE ──► Inference rank 3 Cost: zero memory spike, 4 NICs in parallel, each rank's transfer is independent, scales linearly with rank count. @@ -431,7 +503,7 @@ After (overlay on top of PI): - Memory: 0 GB spike on any single rank regardless of model size. - NIC utilization: 4 outbound streams in parallel on trainer, 4 inbound on rollout. Total bandwidth = sum of per-rank NICs, not capped at one NIC. -- Correctness: per-rank byte-exact byte-exact transfer (PI iter16 `nixl_diff.py` confirmed across all slot types). +- Correctness: per-rank byte-exact transfer (PI iter16 `nixl_diff.py` confirmed across all slot types). **Retiring Step 8**: `PRIMERL_POC_Next_Steps.md` Step 8 ("Eliminate rank-0 allgather β€” per-rank shard publishing") was one of our original P0 roadmap items. It is now absorbed by the pivot: adopting PI's `Slot` + `TransportPlan` gives us this behavior at Phase 1, with no additional MX-side code to write for the *publishing* topology itself. What remains on our side is (a) the MX rendezvous that routes rank-k β†’ rank-k discovery through the server instead of SPG, and (b) the server-side expert-aware index in Β§3.7. @@ -441,7 +513,7 @@ A rollout pod crashes and restarts. Without recovery support it must re-pull its **Server-side state** (a small extension of what pipeline replication already requires): -``` +```text sources_index : Map<(model, version, worker_rank), Set> source_health : Map // TTL-driven liveness, e.g. 10 s ``` @@ -543,7 +615,7 @@ Selection confirmed at first-boot feasibility test in W1 (see `PRIMERL_MX_OVERLA ### 4.3 Deployment shape -``` +```text Node 1 (customer-gpu-w0e, IP 10.0.0.83) β”œβ”€ StatefulSet: prime-rl-mx-trainer-0 β”‚ β”œβ”€ 4Γ— FSDP2 trainer ranks diff --git a/docs/RL/VERL_MX_OVERVIEW.md b/docs/RL/VERL_MX_OVERVIEW.md index 31b2d230..245420b9 100644 --- a/docs/RL/VERL_MX_OVERVIEW.md +++ b/docs/RL/VERL_MX_OVERVIEW.md @@ -309,14 +309,14 @@ verl has two deployment modes for the rollout: | Mode | Ray actors | Status for MX | |------|-----------|--------------| -| **Hybrid (colocated)** | `WorkerDict` does both training and rollout | ❌ No `execute_checkpoint_engine` method β€” `CheckpointEngineManager` fails | -| **Standalone (disaggregated)** | Trainer uses `WorkerDict`, rollout uses `CheckpointEngineWorker` | βœ… Full CE lifecycle available | +| **Hybrid (colocated)** | `WorkerDict` does both training and rollout | **Not supported** β€” `WorkerDict` lacks an `execute_checkpoint_engine` method, so `CheckpointEngineManager` fails. | +| **Standalone (disaggregated)** | Trainer uses `WorkerDict`, rollout uses `CheckpointEngineWorker` | **Supported** β€” full CE lifecycle available. | This is a verl framework constraint, not an MX constraint β€” the built-in `nixl` and `nccl` engines have the same requirement. Our prototype runs in standalone mode on 2 nodes. ### How a weight sync crosses the actor boundary -``` +```text TaskRunner (Node 1) └─► CheckpointEngineManager.update_weights(step=N) # driver-side β”œβ”€β–Ί ray.get([wd0.execute_checkpoint_engine("prepare"), # fan-out to trainer @@ -453,7 +453,7 @@ verl streams `(name, tensor)` pairs through the `CheckpointEngine` API. `MxCheck ### Deployment -``` +```text Node 1 (gke-...-w0e-...-tz1d, IP 10.0.0.83) β”œβ”€ Ray head StatefulSet (verl-mx-head-0) β”‚ β”œβ”€ TaskRunner / CheckpointEngineManager diff --git a/docs/slides/mx-rl-integration-slides.html b/docs/slides/mx-rl-integration-slides.html index f828b583..8b21942e 100644 --- a/docs/slides/mx-rl-integration-slides.html +++ b/docs/slides/mx-rl-integration-slides.html @@ -430,9 +430,6 @@

ModelExpress for Training→Infer

-

- [ INSERT DIAGRAM: diagram-architecture.svg ] -

diff --git a/docs/slides/mx-rl-integration-slides.md b/docs/slides/mx-rl-integration-slides.md index db1bd783..10ffb5d8 100644 --- a/docs/slides/mx-rl-integration-slides.md +++ b/docs/slides/mx-rl-integration-slides.md @@ -26,7 +26,7 @@ On-policy RL (GRPO, PPO, DAPO) alternates between rollout generation on inferenc ### Wall-clock time breakdown (illustrative) -``` +```text | Rollout (40%) | Rew | Train (20%) | β–ˆβ–ˆ REFIT (30%) β–ˆβ–ˆ | β–² BOTTLENECK β–² ``` @@ -50,7 +50,7 @@ Extend MX from inference-to-inference P2P to the trainingβ†’inference boundary. ### High-level data flow -``` +```text Training Workers MX Server Inference Workers (FSDP2 / Megatron) (gRPC + Redis/CRD) (vLLM / SGLang) diff --git a/modelexpress_client/python/modelexpress/refit_receiver.py b/modelexpress_client/python/modelexpress/refit_receiver.py index 26fd5df0..9ec6aa71 100644 --- a/modelexpress_client/python/modelexpress/refit_receiver.py +++ b/modelexpress_client/python/modelexpress/refit_receiver.py @@ -24,7 +24,7 @@ import logging import time from dataclasses import dataclass -from typing import Iterator +from typing import Any, Iterator import torch @@ -36,6 +36,19 @@ logger = logging.getLogger("modelexpress.refit_receiver") +# Maps the dtype string the publisher writes into TensorDescriptor.dtype to a +# torch.dtype. Module-scope so all receiver paths share one definition (and so +# we don't rebuild it on every receive_weights_scratch call). +_DTYPE_MAP: dict[str, torch.dtype] = { + "torch.bfloat16": torch.bfloat16, + "torch.float16": torch.float16, + "torch.float32": torch.float32, + "bfloat16": torch.bfloat16, + "float16": torch.float16, + "float32": torch.float32, +} + + @dataclass class SourceRef: """Lightweight handle to a discovered weight source on the MX Server.""" @@ -134,6 +147,16 @@ def poll_for_source( Returns: A :class:`SourceRef` if a matching source was found, else *None*. + + Note: + ``training_step`` is published in ``SourceIdentity.extra_parameters`` + but ``ListSourcesResponse.instances`` only carries + ``SourceInstanceRef`` (no ``extra_parameters``). To honor the + ``min_step`` contract, this method does a per-candidate + ``get_metadata`` lookup so it can read ``training_step`` from the + publisher's full ``SourceIdentity``. A future server-side fix + (adding ``training_step`` to ``SourceInstanceRef``) will let us + drop the extra round-trip. """ if not self._initialized: raise RuntimeError("Call initialize() before poll_for_source()") @@ -148,7 +171,7 @@ def poll_for_source( response = self._client.list_sources( status_filter=status_filter, ) - except Exception as e: + except Exception as e: # noqa: BLE001 β€” log + retry on transient gRPC error logger.warning(f"list_sources failed: {e}") if time.perf_counter() >= deadline: return None @@ -159,18 +182,54 @@ def poll_for_source( if instance.model_name != model_name: continue + # Resolve training_step from the publisher's SourceIdentity so + # min_step can be enforced. Skip candidates whose metadata is + # unreachable or whose step is below the threshold. + step = self._resolve_training_step(instance) + if step is None or step < min_step: + continue + return SourceRef( mx_source_id=instance.mx_source_id, worker_id=instance.worker_id, model_name=instance.model_name, worker_rank=instance.worker_rank, - training_step=0, + training_step=step, ) if time.perf_counter() >= deadline: return None time.sleep(0.5) + def _resolve_training_step(self, instance: Any) -> int | None: + """Fetch the publisher's ``training_step`` from MX Server metadata. + + ``SourceInstanceRef`` (returned by ``list_sources``) doesn't expose + ``extra_parameters``, so we do a follow-up ``get_metadata`` to read + ``training_step`` from ``SourceIdentity.extra_parameters``. Returns + ``None`` if the metadata isn't available or the step can't be + parsed β€” caller should treat this as "skip candidate". + """ + try: + meta = self._client.get_metadata(instance.mx_source_id, instance.worker_id) + except Exception as e: # noqa: BLE001 β€” gRPC failures are per-candidate, not fatal + logger.debug(f"get_metadata failed for {instance.worker_id}: {e}") + return None + if not getattr(meta, "found", False): + return None + identity = getattr(meta, "identity", None) + if identity is None: + return None + extra = getattr(identity, "extra_parameters", None) or {} + raw = extra.get("training_step") if hasattr(extra, "get") else None + if raw is None: + return None + try: + return int(raw) + except (TypeError, ValueError): + logger.debug(f"training_step={raw!r} not parseable as int; skipping") + return None + def receive_weights( self, source: SourceRef, @@ -280,15 +339,6 @@ def receive_weights_scratch( for t in worker.tensors ] - _DTYPE_MAP = { - "torch.bfloat16": torch.bfloat16, - "torch.float16": torch.float16, - "torch.float32": torch.float32, - "bfloat16": torch.bfloat16, - "float16": torch.float16, - "float32": torch.float32, - } - scratch_tensors: dict[str, torch.Tensor] = {} scratch_shapes: dict[str, tuple[int, ...]] = {} for td in source_tensors: diff --git a/modelexpress_client/python/modelexpress/training_publisher.py b/modelexpress_client/python/modelexpress/training_publisher.py index ad69e051..2848ee6d 100644 --- a/modelexpress_client/python/modelexpress/training_publisher.py +++ b/modelexpress_client/python/modelexpress/training_publisher.py @@ -67,8 +67,15 @@ def __init__( self._worker_id: str = str(uuid.uuid4()) self._mx_source_id: str | None = None self._model_name: str = "" + self._training_framework: str = "unknown" self._initialized = False self._registered = False + # Tracks which publish path has been used. publish_weights and + # publish_layer are mutually exclusive within a publisher's lifetime + # because they hold different sets of tensors registered with NIXL β€” + # mixing them silently invalidates the cached registration. None + # until first publish; "weights" or "layer" thereafter. + self._publish_mode: str | None = None @property def mx_source_id(self) -> str | None: @@ -85,11 +92,20 @@ def initialize( pipeline_parallel_size: int = 1, expert_parallel_size: int = 1, dtype: str = "bfloat16", + training_framework: str = "unknown", ) -> None: """Initialize NIXL agent and MX client. Must be called before any publish operations. Sets up the source identity that inference workers will use to filter compatible sources. + + Args: + training_framework: Identifier for the framework driving this + publisher (``"prime_rl"``, ``"verl"``, ``"nemo_rl"``, ...). + Surfaced in ``SourceIdentity.extra_parameters`` so consumers + can disambiguate sources from different frameworks publishing + to the same MX Server. Default ``"unknown"`` is intentional β€” + callers should pass an explicit value. """ if not is_nixl_available(): raise RuntimeError( @@ -97,6 +113,7 @@ def initialize( ) self._model_name = model_name + self._training_framework = training_framework self._identity_kwargs = dict( model_name=model_name, mx_source_type=p2p_pb2.MX_SOURCE_TYPE_WEIGHTS, @@ -118,7 +135,8 @@ def initialize( self._initialized = True logger.info( f"MxTrainingPublisher initialized: agent={self._agent_name}, " - f"device={self._device_id}, model={model_name}" + f"device={self._device_id}, model={model_name}, " + f"framework={training_framework}" ) def _build_identity(self, step: int) -> p2p_pb2.SourceIdentity: @@ -126,7 +144,7 @@ def _build_identity(self, step: int) -> p2p_pb2.SourceIdentity: return p2p_pb2.SourceIdentity( extra_parameters={ "training_step": str(step), - "training_framework": "prime_rl", + "training_framework": self._training_framework, }, **self._identity_kwargs, ) @@ -170,6 +188,16 @@ def publish_weights( """ if not self._initialized: raise RuntimeError("Call initialize() before publish_weights()") + if self._publish_mode == "layer": + raise RuntimeError( + "publish_weights() and publish_layer() are mutually exclusive: " + "this publisher has already been used in 'layer' mode " + "(publish_layer was called previously). Mixing the two paths " + "leaves NIXL holding only the most recently registered tensor " + "set, which silently invalidates earlier publishes. Use one " + "mode per publisher lifetime." + ) + self._publish_mode = "weights" if not self._registered: self._nixl.register_tensors(named_tensors) @@ -228,6 +256,14 @@ def publish_layer( """ if not self._initialized: raise RuntimeError("Call initialize() before publish_layer()") + if self._publish_mode == "weights": + raise RuntimeError( + "publish_layer() and publish_weights() are mutually exclusive: " + "this publisher has already been used in 'weights' mode " + "(publish_weights was called previously). See publish_weights " + "for the full explanation." + ) + self._publish_mode = "layer" self._nixl.register_tensors(layer_state_dict) metadata = self._nixl.nixl_metadata From dc2856dbf645db9f6c0a6a30c2b6677e9ab6d5ab Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Wed, 29 Apr 2026 11:44:57 -0700 Subject: [PATCH 21/25] docs(RL): add NIXL compression study reproduction guide MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds NIXL_COMPRESSION_STUDY.md to help the NIXL nvCOMP compression team reproduce our RL weight-transfer payloads using our validated PRIME-RL and verl workflows. Three paths documented: 1. Pre-captured data (fastest) β€” pointer to our existing Qwen2.5-1.5B data package (model.safetensors + pre/post RL weights + deltas + KV cache, captured from live GB200 deployment). 2. End-to-end reproduction on GB200 via the PRIME-RL overlay (PR PrimeIntellect-ai/prime-rl#2343) β€” deploy scenario A, exec into trainer pod, capture state_dict + simulate one RL step + dump KV cache. Step-by-step with kubectl commands. 3. Reproduction via verl MxCheckpointEngine (PR ai-dynamo/modelexpress #252) β€” same tensor content, different transport path. Also covers: compression-relevant properties table, per-tensor layout for Qwen3-0.6B and Qwen2.5-1.5B, delta analysis notes (BF16 deltas mostly zero at RL learning rates; FP32 diffs are the meaningful analysis target), NIXL integration point for nvCOMP (transparent β€” compress/decompress at the NIXL layer, no MX or framework changes), and a model-size scaling table for larger captures. Signed-off-by: Kavin Krishnan Made-with: Cursor --- docs/RL/NIXL_COMPRESSION_STUDY.md | 279 ++++++++++++++++++++++++++++++ 1 file changed, 279 insertions(+) create mode 100644 docs/RL/NIXL_COMPRESSION_STUDY.md diff --git a/docs/RL/NIXL_COMPRESSION_STUDY.md b/docs/RL/NIXL_COMPRESSION_STUDY.md new file mode 100644 index 00000000..9bbed3ad --- /dev/null +++ b/docs/RL/NIXL_COMPRESSION_STUDY.md @@ -0,0 +1,279 @@ +# NIXL nvCOMP Compression Study β€” Reproducing with ModelExpress RL Workflows + +**Last Updated**: April 29, 2026 +**Audience**: NIXL compression team (`eschmidt@nvidia.com`) +**Purpose**: Guide the NIXL team to capture and study real RL weight-transfer payloads using our validated PRIME-RL and verl workflows with ModelExpress (MX). + +--- + +## Background + +The NIXL team is evaluating nvCOMP GPU compression on the tensors that flow through NIXL during RL post-training. There are two transfer types: + +1. **RL refit** (training β†’ inference): full model weights, every RL step. +2. **KV cache** (prefill β†’ decode): per-request KV tensors in disaggregated inference. + +We have **two validated end-to-end RL workflows** that produce these payloads over NIXL on GB200: + +| Workflow | Framework | Status | PR | What it exercises | +|----------|-----------|--------|-----|-------------------| +| **PRIME-RL overlay** | PRIME-RL + vLLM | Scenarios A/B/C green on GB200 (20/20 steps each) | [PrimeIntellect-ai/prime-rl#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343) | NIXL RDMA weight push via PI's `NIXLWeightBroadcast` + `TransportPlan`, MX-mediated discovery | +| **verl MxCheckpointEngine** | verl + vLLM | 10 steps green on GB200 | [ai-dynamo/modelexpress#252](https://github.com/ai-dynamo/modelexpress/pull/252) | NIXL RDMA weight transfer via `MxCheckpointEngine` (`CheckpointEngine` plugin) | + +Both produce the **exact same kind of data** the NIXL team requested: raw BF16 weight tensors flowing GPU-to-GPU over NIXL, plus pre/post RL-step weight deltas for delta-compression analysis. + +--- + +## Option 1: Use pre-captured data (fastest) + +We have a ready-made data package captured from a live PRIME-RL deployment on GB200: + +```text +recovery/reinforcement learning/nixl_compression_data/RL_Qwen25/ +β”œβ”€β”€ model.safetensors # 2.9 GB β€” all 338 weight tensors (BF16) +β”œβ”€β”€ weights_pre_rl.safetensors # 3.4 GB β€” weights before optimizer.step() +β”œβ”€β”€ weights_post_rl.safetensors # 3.4 GB β€” weights after 1 AdamW step (lr=5e-6) +β”œβ”€β”€ weight_deltas.safetensors # 3.4 GB β€” elementwise diff (post - pre), BF16 +β”œβ”€β”€ kv_cache/ # 14 MB β€” 56 KV tensors from a 501-token prefill +β”‚ β”œβ”€β”€ layer_0_key.bin # shape [1, 2, 501, 128], BF16 +β”‚ β”œβ”€β”€ layer_0_value.bin +β”‚ β”œβ”€β”€ ... +β”‚ └── manifest.json # per-tensor metadata +β”œβ”€β”€ manifest.json # 66 KB β€” per-weight-tensor metadata +└── README.md # full layout + compression properties +``` + +**Model**: Qwen2.5-1.5B BF16, 28 layers, 1.54B parameters. + +**How to read**: + +```python +from safetensors import safe_open +import torch + +# Weights (the exact tensors NIXL transfers during RL refit) +with safe_open("model.safetensors", framework="pt") as f: + for key in f.keys(): + tensor = f.get_tensor(key) # torch.bfloat16 + raw_bytes = tensor.contiguous().untyped_storage() # raw bytes as on the wire + print(f"{key}: {tensor.shape}, {len(raw_bytes)} bytes") + +# Weight delta (for delta-compression analysis β€” compute in FP32 for precision) +pre = safe_open("weights_pre_rl.safetensors", framework="pt") +post = safe_open("weights_post_rl.safetensors", framework="pt") +for key in pre.keys(): + delta = post.get_tensor(key).float() - pre.get_tensor(key).float() + print(f"{key}: max_abs_delta={delta.abs().max():.2e}") + +# KV cache (the exact tensors transferred prefill β†’ decode via NIXL) +raw = open("kv_cache/layer_0_key.bin", "rb").read() +kv = torch.frombuffer(bytearray(raw), dtype=torch.bfloat16).reshape(1, 2, 501, 128) +``` + +**Key finding on deltas**: At BF16 precision, single-step RL deltas are mostly zero (AdamW updates at lr=5e-6 are below BF16's representable precision). For meaningful delta analysis, compute diffs in FP32. This suggests delta-compression should operate in FP32 and quantize back after. + +--- + +## Option 2: Reproduce end-to-end on GB200 (PRIME-RL overlay) + +Run our validated PRIME-RL overlay workflow and capture weights mid-flight. + +### Prerequisites + +- GKE cluster with GB200 nodes (ARM64, `customer-gpu-o7v` pool or equivalent) +- `kavin` namespace (or your own) with: + - MX Server running: `modelexpress-server..svc.cluster.local:8001` + - Redis backing the MX Server + - `shared-model-cache` PVC for HF model cache + - `nvcr-imagepullsecret` for pulling the overlay image +- `tsh` auth for `nvcr.io/nvidian/dynamo-dev/` + +### Step 1: Deploy the PRIME-RL overlay + +```bash +# Clone and check out the overlay branch +git clone git@github.com:KavinKrishnan/prime-rl.git +cd prime-rl +git checkout kavink/mx-on-nixl + +# Build the ARM64 image (or use the pre-built one) +# Pre-built: nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.2 +docker buildx build --platform linux/arm64 \ + -f docker/Dockerfile.mx-on-nixl \ + -t nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.2 \ + --push . + +# Deploy scenario A (baseline β€” PI's NIXL transport, no MX env vars) +cd k8s/prime-rl-mx-on-nixl +./run.sh deploy A + +# Watch until all 3 pods are Running +./run.sh status +``` + +### Step 2: Verify the RL loop is running + +```bash +# Trainer should show "Step N | Time: Xs" lines +kubectl -n kavin logs prime-rl-mx-on-nixl-trainer-0 --tail=20 | grep "SUCCESS.*Step" + +# Inference should show /update_weights 200 OK +kubectl -n kavin logs prime-rl-mx-on-nixl-inference-0 | grep "update_weights.*200" +``` + +### Step 3: Capture weights from the running trainer + +```bash +# Exec into the trainer pod +kubectl -n kavin exec -it prime-rl-mx-on-nixl-trainer-0 -- bash + +# Inside the pod β€” capture pre/post RL weights + KV cache +cd /tmp +/app/.venv/bin/python - << 'PYEOF' +import torch, json, os, time +from pathlib import Path +from transformers import AutoModelForCausalLM, AutoTokenizer +from safetensors.torch import save_file + +model_name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" +out = Path("/tmp/nixl_compression_capture") +out.mkdir(exist_ok=True) + +print("Loading model...") +model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="cpu") +tokenizer = AutoTokenizer.from_pretrained(model_name) + +# 1. Capture current weights (= what NIXL transfers during refit) +print("Saving current weights...") +sd = {k: v.clone() for k, v in model.state_dict().items()} +save_file(sd, str(out / "weights_current.safetensors")) + +# 2. Simulate one RL step for delta capture +print("Simulating one RL step...") +model.to("cuda:0") +model.train() +optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6) +inputs = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt").to("cuda:0") +loss = model(**inputs, labels=inputs["input_ids"]).loss +loss.backward() +optimizer.step() + +sd_post = {k: v.cpu().clone() for k, v in model.state_dict().items()} +save_file(sd_post, str(out / "weights_post_step.safetensors")) + +# 3. Compute delta +deltas = {} +for k in sd: + d = sd_post[k].float() - sd[k].float() + deltas[k] = d.to(torch.bfloat16) +save_file(deltas, str(out / "weight_deltas.safetensors")) + +# 4. KV cache from a prefill pass +print("Capturing KV cache...") +model.eval() +kv_out = out / "kv_cache" +kv_out.mkdir(exist_ok=True) +with torch.no_grad(): + outputs = model(**inputs, use_cache=True) +manifest = {"tensors": []} +for i, layer_kv in enumerate(outputs.past_key_values): + for j, name in enumerate(["key", "value"]): + t = layer_kv[j].cpu().contiguous() + fname = f"layer_{i}_{name}.bin" + (kv_out / fname).write_bytes(t.numpy().tobytes()) + manifest["tensors"].append({ + "name": f"layer_{i}_{name}", "shape": list(t.shape), + "dtype": "bfloat16", "size_bytes": t.numel() * 2, "file": fname + }) +json.dump(manifest, open(kv_out / "manifest.json", "w"), indent=2) + +# 5. Write weight manifest +w_manifest = {"model": model_name, "tensors": []} +for k, v in sd.items(): + w_manifest["tensors"].append({ + "name": k, "shape": list(v.shape), "dtype": str(v.dtype), + "size_bytes": v.numel() * v.element_size() + }) +json.dump(w_manifest, open(out / "manifest.json", "w"), indent=2) + +print(f"Done. Files in {out}") +PYEOF + +# Copy out of the pod +exit +kubectl -n kavin cp prime-rl-mx-on-nixl-trainer-0:/tmp/nixl_compression_capture ./nixl_capture +``` + +### Step 4: Tear down + +```bash +./run.sh clean +``` + +--- + +## Option 3: Reproduce with verl MxCheckpointEngine + +The verl integration uses the same MX client but through verl's `CheckpointEngine` plugin. This path captures the weights as they flow through `MxCheckpointEngine.send_weights()` / `receive_weights()`. + +Deployment docs: `docs/RL/VERL_MX_OVERVIEW.md` Β§6 in the modelexpress repo. + +The capture approach is the same as Option 2 (exec into the trainer pod, save state dict pre/post step) since the weight tensors are identical β€” both frameworks produce `model.named_parameters()` in BF16. The difference is the transport path (verl's bucket+ZMQ metadata vs prime-rl's TransportPlan+slot system), which doesn't affect the tensor content. + +--- + +## What to capture for the compression study + +| Artifact | File | Size (Qwen3-0.6B) | What it represents | +|----------|------|-------|---------------------| +| **Current weights** | `weights_current.safetensors` | ~1.2 GB | Exact tensors registered with NIXL and RDMA-written to inference GPU every RL step | +| **Post-step weights** | `weights_post_step.safetensors` | ~1.2 GB | After one AdamW step (lr=5e-6) | +| **Weight deltas** | `weight_deltas.safetensors` | ~1.2 GB | `post - pre` in BF16 (mostly zero β€” compute in FP32 for real deltas) | +| **KV cache** | `kv_cache/*.bin` | ~14 MB | Prefill output transferred to decode workers via NIXL | +| **Manifest** | `manifest.json` | ~30 KB | Per-tensor: name, shape, dtype, size_bytes | + +### Larger models for more representative data + +The steps above use Qwen3-0.6B (our scenario A model). For larger models closer to production: + +| Model | Params | Weight payload | Notes | +|-------|--------|----------------|-------| +| Qwen3-0.6B (above) | 0.6B | ~1.2 GB | Validated in PR #2343 scenarios A/B/C | +| Qwen2.5-1.5B | 1.5B | ~3 GB | Pre-captured data already available (see Option 1) | +| Qwen2.5-7B | 7.6B | ~15 GB | T1 model in our overlay plan | +| Qwen3-MoE (PI offered spec) | MoE | varies | Would exercise `ExpertSlot` + per-expert tensors β€” most representative for MoE compression | + +For models requiring multiple GPUs, the weights are FSDP-sharded β€” each rank's shard is `total / num_ranks` in size. The bytes on the wire per-rank are the shard size, not the full model. + +--- + +## Compression-relevant properties + +| Property | Weights | KV Cache | Delta (FP32) | +|----------|---------|----------|--------------| +| **Dtype on wire** | BF16 (2 B/elem) | BF16 (2 B/elem) | BF16 stored, but FP32 is the meaningful analysis dtype | +| **Value distribution** | Normal, centered ~0, std 0.01–0.1 | Wider, context-dependent | Very small magnitude (~1e-8 to 1e-6 per element) | +| **Sparsity** | Dense (no zeros) | Dense | ~100% zero at BF16 precision; structured-sparse at FP32 | +| **Best compression angle** | Entropy coding on mantissa bits | Temporal locality across layers | FP32 delta + entropy coding β€” high compressibility expected | +| **Transfer frequency** | Every RL step (~5–60 s) | Every request | Once for analysis | +| **Bucket size on wire** | 596 MB (measured in scenario A/B/C) | per-request, scales with seq_len | N/A | + +### NIXL integration point for nvCOMP + +If nvCOMP compression is added at the NIXL layer, the integration is transparent to both MX and the RL frameworks: + +```text +Current: + Training GPU β†’ NIXL register β†’ RDMA WRITE (raw bytes) β†’ Inference GPU + +With NIXL-layer nvCOMP: + Training GPU β†’ NIXL register β†’ nvCOMP compress (GPU) β†’ RDMA WRITE (compressed) β†’ nvCOMP decompress (GPU) β†’ Inference GPU +``` + +No changes to `MxTrainingPublisher`, `MxRefitReceiver`, `NIXLWeightBroadcast`, `TransportPlan`, or the MX Server protocol. Compression is internal to NIXL's transfer path. Our bucket-streaming pattern is preserved β€” compression happens per-bucket. + +--- + +## Questions? + +Reach out to Kavin Krishnan (`kavink@nvidia.com`) for access to the pre-captured data or help reproducing on a cluster. The PRIME-RL overlay branch (`KavinKrishnan/prime-rl:kavink/mx-on-nixl`) and the modelexpress RL branch (`ai-dynamo/modelexpress:kavink/RL`) are the entry points. From 5be0cefb2ca06af52dbeabe29f44b0867f773385 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Wed, 29 Apr 2026 11:47:29 -0700 Subject: [PATCH 22/25] docs(RL): clarify NIXL compression data package is request-only The pre-captured Qwen2.5-1.5B data package referenced in Option 1 of NIXL_COMPRESSION_STUDY.md isn't in this repo (binary tensors at GB scale aren't appropriate to commit) and the path I had previously shown was an internal local checkout. Replace with explicit "request from kavink@nvidia.com" framing and call out the appropriate channels (NV S3, internal share, or direct upload to eschmidt@nvidia.com per the original ask). Add the total package size (~14 GB) so the NIXL team knows what to expect bandwidth-wise. Update the "larger models" cross-reference accordingly. Signed-off-by: Kavin Krishnan Made-with: Cursor --- docs/RL/NIXL_COMPRESSION_STUDY.md | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/docs/RL/NIXL_COMPRESSION_STUDY.md b/docs/RL/NIXL_COMPRESSION_STUDY.md index 9bbed3ad..6680cac4 100644 --- a/docs/RL/NIXL_COMPRESSION_STUDY.md +++ b/docs/RL/NIXL_COMPRESSION_STUDY.md @@ -24,12 +24,14 @@ Both produce the **exact same kind of data** the NIXL team requested: raw BF16 w --- -## Option 1: Use pre-captured data (fastest) +## Option 1: Request the pre-captured data package (fastest) -We have a ready-made data package captured from a live PRIME-RL deployment on GB200: +We have a ready-made data package captured from a live PRIME-RL deployment on GB200. **It's not in this repo** (binary tensors at GB scale aren't appropriate to commit) β€” request access from `kavink@nvidia.com` and we'll share via the appropriate channel (NV S3 bucket, internal share, or direct upload to your `eschmidt@nvidia.com` inbox per the original request). + +Package contents: ```text -recovery/reinforcement learning/nixl_compression_data/RL_Qwen25/ +RL_Qwen25/ β”œβ”€β”€ model.safetensors # 2.9 GB β€” all 338 weight tensors (BF16) β”œβ”€β”€ weights_pre_rl.safetensors # 3.4 GB β€” weights before optimizer.step() β”œβ”€β”€ weights_post_rl.safetensors # 3.4 GB β€” weights after 1 AdamW step (lr=5e-6) @@ -43,7 +45,7 @@ recovery/reinforcement learning/nixl_compression_data/RL_Qwen25/ └── README.md # full layout + compression properties ``` -**Model**: Qwen2.5-1.5B BF16, 28 layers, 1.54B parameters. +**Model**: Qwen2.5-1.5B BF16, 28 layers, 1.54B parameters. ~14 GB total package size. **How to read**: @@ -239,7 +241,7 @@ The steps above use Qwen3-0.6B (our scenario A model). For larger models closer | Model | Params | Weight payload | Notes | |-------|--------|----------------|-------| | Qwen3-0.6B (above) | 0.6B | ~1.2 GB | Validated in PR #2343 scenarios A/B/C | -| Qwen2.5-1.5B | 1.5B | ~3 GB | Pre-captured data already available (see Option 1) | +| Qwen2.5-1.5B | 1.5B | ~3 GB | Pre-captured package available on request (see Option 1) | | Qwen2.5-7B | 7.6B | ~15 GB | T1 model in our overlay plan | | Qwen3-MoE (PI offered spec) | MoE | varies | Would exercise `ExpertSlot` + per-expert tensors β€” most representative for MoE compression | From 848e7f7731ace54fc170b20558567f054c0581a8 Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Wed, 29 Apr 2026 12:01:07 -0700 Subject: [PATCH 23/25] docs(RL): publish NIXL compression study capture scripts MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the two scripts that produced the Qwen2.5-1.5B data package we referenced in NIXL_COMPRESSION_STUDY.md so the NIXL team can reproduce captures themselves on different models / sequence lengths / clusters without going through us as a manual relay. New files: docs/RL/scripts/capture_weights_and_kv.py Standalone β€” any HF model, any host (CPU or single GPU), no cluster / RL framework needed. CLI flags for model, dtype, device, output dir, weights/KV-only modes, KV seq_len. docs/RL/scripts/capture_on_pod.py Inside-a-running-RL-pod variant. Generalized vs the original Qwen2.5-1.5B-only capture: --model, --out, --kv-seq-len, --lr flags. Captures pre/post RL weights + simulated AdamW step delta + KV cache in one pass. Produces the four-directory layout (weights_pre_rl/, weights_post_rl/, weight_deltas/, kv_cache/) we shipped to the compression team. docs/RL/scripts/README.md Quick reference for both scripts: when to use each, complete CLI examples, output layout, the BF16-deltas-are-mostly-zero note + FP32 analysis snippet, pointer back to the main study doc. Updated: docs/RL/NIXL_COMPRESSION_STUDY.md Option 2 now points at scripts/capture_on_pod.py with kubectl cp + exec invocation instead of an inlined heredoc Python block. Added Option-2-Step-4 ("standalone capture without a running RL deployment") pointing at capture_weights_and_kv.py for users who don't want to deploy the full overlay. The original on-disk capture scripts in our internal recovery directory are unchanged; this just publishes a generalized, flag-driven version of each into the public docs tree. Signed-off-by: Kavin Krishnan Made-with: Cursor --- docs/RL/NIXL_COMPRESSION_STUDY.md | 125 ++++------- docs/RL/scripts/README.md | 105 ++++++++++ docs/RL/scripts/capture_on_pod.py | 197 ++++++++++++++++++ docs/RL/scripts/capture_weights_and_kv.py | 242 ++++++++++++++++++++++ 4 files changed, 581 insertions(+), 88 deletions(-) create mode 100644 docs/RL/scripts/README.md create mode 100644 docs/RL/scripts/capture_on_pod.py create mode 100644 docs/RL/scripts/capture_weights_and_kv.py diff --git a/docs/RL/NIXL_COMPRESSION_STUDY.md b/docs/RL/NIXL_COMPRESSION_STUDY.md index 6680cac4..21efa2b2 100644 --- a/docs/RL/NIXL_COMPRESSION_STUDY.md +++ b/docs/RL/NIXL_COMPRESSION_STUDY.md @@ -78,7 +78,7 @@ kv = torch.frombuffer(bytearray(raw), dtype=torch.bfloat16).reshape(1, 2, 501, 1 ## Option 2: Reproduce end-to-end on GB200 (PRIME-RL overlay) -Run our validated PRIME-RL overlay workflow and capture weights mid-flight. +Run our validated PRIME-RL overlay workflow and capture weights mid-flight using the published [`scripts/`](./scripts/) directory. ### Prerequisites @@ -93,7 +93,6 @@ Run our validated PRIME-RL overlay workflow and capture weights mid-flight. ### Step 1: Deploy the PRIME-RL overlay ```bash -# Clone and check out the overlay branch git clone git@github.com:KavinKrishnan/prime-rl.git cd prime-rl git checkout kavink/mx-on-nixl @@ -108,105 +107,55 @@ docker buildx build --platform linux/arm64 \ # Deploy scenario A (baseline β€” PI's NIXL transport, no MX env vars) cd k8s/prime-rl-mx-on-nixl ./run.sh deploy A - -# Watch until all 3 pods are Running -./run.sh status +./run.sh status # wait until all 3 pods are Running ``` ### Step 2: Verify the RL loop is running ```bash -# Trainer should show "Step N | Time: Xs" lines kubectl -n kavin logs prime-rl-mx-on-nixl-trainer-0 --tail=20 | grep "SUCCESS.*Step" - -# Inference should show /update_weights 200 OK kubectl -n kavin logs prime-rl-mx-on-nixl-inference-0 | grep "update_weights.*200" ``` -### Step 3: Capture weights from the running trainer +### Step 3: Capture using the published script + +We ship `capture_on_pod.py` in [`scripts/`](./scripts/) β€” same script that produced our pre-captured Qwen2.5-1.5B package. It captures pre/post RL weights, simulates one AdamW step, computes deltas, and dumps a KV cache prefill, all in one pass. ```bash -# Exec into the trainer pod -kubectl -n kavin exec -it prime-rl-mx-on-nixl-trainer-0 -- bash - -# Inside the pod β€” capture pre/post RL weights + KV cache -cd /tmp -/app/.venv/bin/python - << 'PYEOF' -import torch, json, os, time -from pathlib import Path -from transformers import AutoModelForCausalLM, AutoTokenizer -from safetensors.torch import save_file - -model_name = "PrimeIntellect/Qwen3-0.6B-Reverse-Text-SFT" -out = Path("/tmp/nixl_compression_capture") -out.mkdir(exist_ok=True) - -print("Loading model...") -model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16, device_map="cpu") -tokenizer = AutoTokenizer.from_pretrained(model_name) - -# 1. Capture current weights (= what NIXL transfers during refit) -print("Saving current weights...") -sd = {k: v.clone() for k, v in model.state_dict().items()} -save_file(sd, str(out / "weights_current.safetensors")) - -# 2. Simulate one RL step for delta capture -print("Simulating one RL step...") -model.to("cuda:0") -model.train() -optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6) -inputs = tokenizer("The quick brown fox jumps over the lazy dog", return_tensors="pt").to("cuda:0") -loss = model(**inputs, labels=inputs["input_ids"]).loss -loss.backward() -optimizer.step() - -sd_post = {k: v.cpu().clone() for k, v in model.state_dict().items()} -save_file(sd_post, str(out / "weights_post_step.safetensors")) - -# 3. Compute delta -deltas = {} -for k in sd: - d = sd_post[k].float() - sd[k].float() - deltas[k] = d.to(torch.bfloat16) -save_file(deltas, str(out / "weight_deltas.safetensors")) - -# 4. KV cache from a prefill pass -print("Capturing KV cache...") -model.eval() -kv_out = out / "kv_cache" -kv_out.mkdir(exist_ok=True) -with torch.no_grad(): - outputs = model(**inputs, use_cache=True) -manifest = {"tensors": []} -for i, layer_kv in enumerate(outputs.past_key_values): - for j, name in enumerate(["key", "value"]): - t = layer_kv[j].cpu().contiguous() - fname = f"layer_{i}_{name}.bin" - (kv_out / fname).write_bytes(t.numpy().tobytes()) - manifest["tensors"].append({ - "name": f"layer_{i}_{name}", "shape": list(t.shape), - "dtype": "bfloat16", "size_bytes": t.numel() * 2, "file": fname - }) -json.dump(manifest, open(kv_out / "manifest.json", "w"), indent=2) - -# 5. Write weight manifest -w_manifest = {"model": model_name, "tensors": []} -for k, v in sd.items(): - w_manifest["tensors"].append({ - "name": k, "shape": list(v.shape), "dtype": str(v.dtype), - "size_bytes": v.numel() * v.element_size() - }) -json.dump(w_manifest, open(out / "manifest.json", "w"), indent=2) - -print(f"Done. Files in {out}") -PYEOF - -# Copy out of the pod -exit -kubectl -n kavin cp prime-rl-mx-on-nixl-trainer-0:/tmp/nixl_compression_capture ./nixl_capture +# Copy the script into the trainer pod +kubectl cp docs/RL/scripts/capture_on_pod.py \ + kavin/prime-rl-mx-on-nixl-trainer-0:/tmp/capture.py + +# Run it inside the pod (overlay image's interpreter is /app/.venv/bin/python) +kubectl exec kavin/prime-rl-mx-on-nixl-trainer-0 -- /app/.venv/bin/python /tmp/capture.py \ + --model Qwen/Qwen2.5-1.5B \ + --out /tmp/nixl_capture \ + --kv-seq-len 512 \ + --lr 5e-6 + +# Copy the results back +kubectl cp kavin/prime-rl-mx-on-nixl-trainer-0:/tmp/nixl_capture ./RL_capture ``` -### Step 4: Tear down +Output `RL_capture/` contains four sub-directories (`weights_pre_rl/`, `weights_post_rl/`, `weight_deltas/`, `kv_cache/`) each with raw `.bin` files plus a `manifest.json`. See [`scripts/README.md`](./scripts/README.md) for the full layout + flag reference. + +### Step 4 (optional): Capture without a running RL deployment + +If reproducing the overlay is more cluster work than the data is worth, [`scripts/capture_weights_and_kv.py`](./scripts/capture_weights_and_kv.py) is the **standalone** variant β€” works on any host (CPU or single GPU), no Kubernetes / RL framework required: + +```bash +pip install torch transformers safetensors + +python docs/RL/scripts/capture_weights_and_kv.py \ + --model Qwen/Qwen2.5-1.5B \ + --output-dir ./nixl_data \ + --dtype bfloat16 \ + --device cpu +``` + +Doesn't simulate an RL step (no pre/post/delta), but produces the same weight + KV cache layout the NIXL team can compress against. + +### Step 5: Tear down ```bash ./run.sh clean diff --git a/docs/RL/scripts/README.md b/docs/RL/scripts/README.md new file mode 100644 index 00000000..ce71366f --- /dev/null +++ b/docs/RL/scripts/README.md @@ -0,0 +1,105 @@ +# NIXL Compression Study β€” Capture Scripts + +Two scripts for producing the data described in [`../NIXL_COMPRESSION_STUDY.md`](../NIXL_COMPRESSION_STUDY.md). + +| Script | When to use | +|--------|-------------| +| `capture_weights_and_kv.py` | **Standalone** β€” capture from any HuggingFace model on any host. Doesn't require a running RL deployment. Just downloads the model and dumps weights + KV cache. CLI flags for model, dtype, device, output dir. | +| `capture_on_pod.py` | **Inside a running RL pod** β€” exec into a trainer pod and capture pre-step weights, simulate one AdamW step, capture post-step weights + delta + KV cache in one pass. Produces the four-directory layout we shipped to the NIXL team for Qwen2.5-1.5B. | + +## Standalone capture (any model, no cluster needed) + +```bash +pip install torch transformers safetensors + +# Smallest model β€” ~3 GB output, ~5 minutes total +python capture_weights_and_kv.py \ + --model Qwen/Qwen2.5-1.5B \ + --output-dir ./nixl_data \ + --dtype bfloat16 \ + --device cpu \ + --kv-seq-len 512 + +# Larger model +python capture_weights_and_kv.py \ + --model meta-llama/Llama-3.1-8B-Instruct \ + --output-dir ./nixl_data \ + --dtype bfloat16 \ + --device cuda:0 \ + --kv-seq-len 2048 + +# Weights only / KV only +python capture_weights_and_kv.py --model --output-dir --weights-only +python capture_weights_and_kv.py --model --output-dir --kv-only +``` + +Output layout: + +```text +/ +β”œβ”€β”€ weights// +β”‚ β”œβ”€β”€ tensors/*.bin # one file per parameter, raw bytes (BF16) +β”‚ └── manifest.json # name, shape, dtype, size, layer index, classification, stats +└── kvcache// + β”œβ”€β”€ layer_N_key.bin # one file per (layer, key/value) + β”œβ”€β”€ layer_N_value.bin + └── manifest.json +``` + +## Capture from a running pod (with RL-step simulation) + +This script is what produced the `RL_Qwen25/` package referenced in the NIXL request. It captures weights pre- and post- a simulated AdamW step, then computes the delta: + +```bash +# Copy the script into the pod +kubectl cp capture_on_pod.py /:/tmp/capture.py + +# Run it (uses /app/.venv/bin/python in our overlay image; adjust if different) +kubectl exec / -- /app/.venv/bin/python /tmp/capture.py \ + --model Qwen/Qwen2.5-1.5B \ + --out /tmp/nixl_capture \ + --kv-seq-len 512 \ + --lr 5e-6 + +# Copy results back +kubectl cp /:/tmp/nixl_capture ./RL_capture +``` + +Output layout (matches what we shipped to the NIXL team): + +```text +nixl_capture/ +β”œβ”€β”€ weights_pre_rl/ # pre-step weight tensors + manifest.json +β”œβ”€β”€ weights_post_rl/ # post-step weight tensors + manifest.json +β”œβ”€β”€ weight_deltas/ # post - pre (BF16; mostly zero β€” see note below) +└── kv_cache/ # one prefill pass output + manifest.json +``` + +### Note on BF16 deltas + +A single AdamW step at `lr=5e-6` produces parameter updates of magnitude ~1e-8 to 1e-6, which is **below BF16's representable precision** at typical weight magnitudes (0.01–0.1). The `weight_deltas/` files will therefore be mostly zero in BF16. + +For meaningful delta-compression analysis, compute the diff in FP32 from the pre/post safetensors: + +```python +import torch +from safetensors import safe_open + +with safe_open("weights_pre_rl/...", framework="pt") as pre, \ + safe_open("weights_post_rl/...", framework="pt") as post: + for k in pre.keys(): + delta_fp32 = post.get_tensor(k).float() - pre.get_tensor(k).float() + if delta_fp32.abs().max() > 0: + print(f"{k}: max_abs_delta={delta_fp32.abs().max():.2e}") +``` + +The pre/post tensors are saved as raw BF16 (the on-the-wire dtype). The FP32 delta is the meaningful analysis target β€” this is the signal nvCOMP would compress in a delta-transfer scheme. + +## What gets captured + +See [`../NIXL_COMPRESSION_STUDY.md`](../NIXL_COMPRESSION_STUDY.md) for the full breakdown of: + +- Per-tensor layout (Qwen3-0.6B and Qwen2.5-1.5B examples) +- KV cache shape + scaling table +- Compression-relevant properties +- Where these tensors fit in the NIXL transfer path diff --git a/docs/RL/scripts/capture_on_pod.py b/docs/RL/scripts/capture_on_pod.py new file mode 100644 index 00000000..5e3e510b --- /dev/null +++ b/docs/RL/scripts/capture_on_pod.py @@ -0,0 +1,197 @@ +#!/usr/bin/env python3 +"""Capture weight and KV cache data from a running PRIME-RL / verl deployment. + +Designed to be exec'd inside a trainer pod. Captures four artifacts in the +same shape we shipped to the NIXL nvCOMP compression team for Qwen2.5-1.5B: + + weights_pre_rl/ raw .bin tensors + manifest.json (pre-step state dict) + weights_post_rl/ raw .bin tensors + manifest.json (after one AdamW step) + weight_deltas/ raw .bin tensors + manifest.json (post - pre) + kv_cache/ raw .bin tensors + manifest.json (one prefill pass) + +Then a final summary line tells you how to `kubectl cp` it out. + +Usage (inside pod): + python3 capture_on_pod.py + python3 capture_on_pod.py --model Qwen/Qwen2.5-7B --out /tmp/nixl_capture --kv-seq-len 1024 + +Usage (from host, no pod): + kubectl cp capture_on_pod.py /:/tmp/capture.py + kubectl exec / -- python3 /tmp/capture.py --model + kubectl cp /:/tmp/nixl_capture ./RL_capture +""" +import argparse, json, os, time, torch +from pathlib import Path +from transformers import AutoModelForCausalLM, AutoTokenizer + +parser = argparse.ArgumentParser() +parser.add_argument("--model", default="Qwen/Qwen2.5-1.5B", + help="HuggingFace model name (must match the running RL deployment)") +parser.add_argument("--out", default="/tmp/nixl_capture", + help="Output directory inside the pod") +parser.add_argument("--kv-seq-len", type=int, default=512, + help="Sequence length for the KV cache prefill pass") +parser.add_argument("--lr", type=float, default=5e-6, + help="Learning rate for the simulated AdamW step (matches PRIME-RL default)") +args = parser.parse_args() + +MODEL = args.model +OUT = Path(args.out) +OUT.mkdir(parents=True, exist_ok=True) + +def tensor_stats(t): + ft = t.float() + return {"min": float(ft.min()), "max": float(ft.max()), "mean": float(ft.mean()), + "std": float(ft.std()), "abs_mean": float(ft.abs().mean()), + "zero_frac": float((t == 0).float().mean())} + +def classify(name): + for k, v in [("embed", "embedding"), ("lm_head", "lm_head"), ("norm", "norm"), + ("q_proj", "attn_q"), ("k_proj", "attn_k"), ("v_proj", "attn_v"), + ("o_proj", "attn_o"), ("gate_proj", "mlp_gate"), ("up_proj", "mlp_up"), + ("down_proj", "mlp_down")]: + if k in name: return v + return "other" + +def layer_idx(name): + parts = name.split(".") + for i, p in enumerate(parts): + if p == "layers" and i + 1 < len(parts) and parts[i+1].isdigit(): + return int(parts[i+1]) + return -1 + +print(f"Loading {MODEL}...") +tokenizer = AutoTokenizer.from_pretrained(MODEL, trust_remote_code=True) +model = AutoModelForCausalLM.from_pretrained(MODEL, torch_dtype=torch.bfloat16, trust_remote_code=True) +model.eval() + +# --- 1. Dump weights (pre-RL step) --- +print("\n=== Capturing pre-RL weights ===") +wdir = OUT / "weights_pre_rl" +wdir.mkdir(exist_ok=True) +manifest = {"model": MODEL, "dtype": "bfloat16", "capture": "pre_rl_weights", + "description": "Exact weight tensors transferred during RL refit (training->inference via NIXL RDMA)", + "tensors": []} +total = 0 +for name, param in model.named_parameters(): + t = param.data.contiguous() + raw = bytes(t.untyped_storage())[:t.numel() * t.element_size()] + fname = name.replace(".", "_") + ".bin" + (wdir / fname).write_bytes(raw) + manifest["tensors"].append({"name": name, "file": fname, "shape": list(t.shape), + "dtype": str(t.dtype), "size_bytes": len(raw), "numel": t.numel(), + "layer": layer_idx(name), "type": classify(name), "stats": tensor_stats(t)}) + total += len(raw) +manifest["total_bytes"] = total +manifest["total_gb"] = round(total / 1e9, 3) +manifest["num_tensors"] = len(manifest["tensors"]) +cfg = model.config +manifest["model_config"] = {"num_hidden_layers": cfg.num_hidden_layers, "hidden_size": cfg.hidden_size, + "intermediate_size": cfg.intermediate_size, "num_attention_heads": cfg.num_attention_heads, + "num_key_value_heads": cfg.num_key_value_heads, "vocab_size": cfg.vocab_size} +(wdir / "manifest.json").write_text(json.dumps(manifest, indent=2)) +print(f" {manifest['num_tensors']} tensors, {manifest['total_gb']} GB -> {wdir}") + +# --- 2. KV cache --- +print(f"\n=== Capturing KV cache (seq_len={args.kv_seq_len}) ===") +kvdir = OUT / "kv_cache" +kvdir.mkdir(exist_ok=True) +prompt = "The quick brown fox jumps over the lazy dog. " * (args.kv_seq_len // 10 + 1) +inputs = tokenizer(prompt, return_tensors="pt", max_length=args.kv_seq_len, truncation=True) +with torch.no_grad(): + outputs = model(**inputs, use_cache=True) +kv = outputs.past_key_values +kv_manifest = {"model": MODEL, "capture": "kv_cache_prefill", "seq_len": int(inputs["input_ids"].shape[1]), + "description": "KV cache from prefill pass - transferred prefill->decode via NIXL in disagg inference", + "tensors": []} +kv_total = 0 +for li, layer_kv in enumerate(kv): + for ki, kn in enumerate(["key", "value"]): + t = layer_kv[ki].contiguous() + raw = bytes(t.untyped_storage())[:t.numel() * t.element_size()] + fname = f"layer_{li}_{kn}.bin" + (kvdir / fname).write_bytes(raw) + kv_manifest["tensors"].append({"name": f"layer_{li}.{kn}", "file": fname, + "shape": list(t.shape), "dtype": str(t.dtype), "size_bytes": len(raw), + "layer": li, "kv_type": kn, "stats": tensor_stats(t)}) + kv_total += len(raw) +kv_manifest["total_bytes"] = kv_total +kv_manifest["total_mb"] = round(kv_total / 1e6, 3) +kv_manifest["kv_config"] = {"num_layers": len(kv), + "num_kv_heads": cfg.num_key_value_heads, + "head_dim": cfg.hidden_size // cfg.num_attention_heads} +(kvdir / "manifest.json").write_text(json.dumps(kv_manifest, indent=2)) +print(f" {len(kv_manifest['tensors'])} tensors, {kv_manifest['total_mb']} MB -> {kvdir}") + +# --- 3. Simulate one RL step and capture post-RL weights + delta --- +print(f"\n=== Simulating RL step (AdamW, lr={args.lr}, dummy loss) ===") +model.train() +optimizer = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=0.01) +dummy_input = tokenizer("Hello world", return_tensors="pt") +output = model(**dummy_input, labels=dummy_input["input_ids"]) +output.loss.backward() +optimizer.step() +optimizer.zero_grad() +model.eval() + +print("\n=== Capturing post-RL weights ===") +wdir2 = OUT / "weights_post_rl" +wdir2.mkdir(exist_ok=True) +manifest2 = {"model": MODEL, "dtype": "bfloat16", "capture": "post_rl_weights", + "description": "Weights after 1 RL optimizer step (lr=5e-6, AdamW)", "tensors": []} +total2 = 0 +for name, param in model.named_parameters(): + t = param.data.contiguous() + raw = bytes(t.untyped_storage())[:t.numel() * t.element_size()] + fname = name.replace(".", "_") + ".bin" + (wdir2 / fname).write_bytes(raw) + manifest2["tensors"].append({"name": name, "file": fname, "shape": list(t.shape), + "dtype": str(t.dtype), "size_bytes": len(raw), "numel": t.numel(), + "layer": layer_idx(name), "type": classify(name), "stats": tensor_stats(t)}) + total2 += len(raw) +manifest2["total_bytes"] = total2 +manifest2["total_gb"] = round(total2 / 1e9, 3) +(wdir2 / "manifest.json").write_text(json.dumps(manifest2, indent=2)) +print(f" {len(manifest2['tensors'])} tensors, {manifest2['total_gb']} GB -> {wdir2}") + +# --- 4. Compute and save deltas --- +print("\n=== Computing weight deltas (post - pre) ===") +ddir = OUT / "weight_deltas" +ddir.mkdir(exist_ok=True) +delta_manifest = {"model": MODEL, "capture": "weight_delta_1_step", + "description": ( + f"Difference between weights after 1 RL step vs before. " + f"RL uses lr={args.lr} so deltas are tiny β€” at BF16 most are exactly zero " + f"(below mantissa precision). For meaningful delta-compression analysis, " + f"compute diffs in FP32 from the pre/post safetensors instead of using " + f"this BF16-stored delta directly." + ), + "tensors": []} +dtotal = 0 +pre_files = {m["name"]: m["file"] for m in manifest["tensors"]} +for info in manifest2["tensors"]: + pre_raw = (OUT / "weights_pre_rl" / pre_files[info["name"]]).read_bytes() + post_raw = (wdir2 / info["file"]).read_bytes() + pre_t = torch.frombuffer(bytearray(pre_raw), dtype=torch.bfloat16).reshape(info["shape"]) + post_t = torch.frombuffer(bytearray(post_raw), dtype=torch.bfloat16).reshape(info["shape"]) + delta = post_t - pre_t + delta_raw = bytes(delta.contiguous().untyped_storage())[:delta.numel() * delta.element_size()] + fname = "delta_" + info["file"] + (ddir / fname).write_bytes(delta_raw) + delta_manifest["tensors"].append({"name": info["name"], "file": fname, + "shape": info["shape"], "dtype": "bfloat16", "size_bytes": len(delta_raw), + "stats": tensor_stats(delta)}) + dtotal += len(delta_raw) +delta_manifest["total_bytes"] = dtotal +delta_manifest["total_gb"] = round(dtotal / 1e9, 3) +(ddir / "manifest.json").write_text(json.dumps(delta_manifest, indent=2)) +print(f" {len(delta_manifest['tensors'])} delta tensors, {delta_manifest['total_gb']} GB -> {ddir}") + +# --- Summary --- +print(f"\n=== DONE ===") +print(f"Output: {OUT}") +print(f" weights_pre_rl/ : {manifest['total_gb']} GB ({manifest['num_tensors']} tensors)") +print(f" weights_post_rl/ : {manifest2['total_gb']} GB") +print(f" weight_deltas/ : {delta_manifest['total_gb']} GB") +print(f" kv_cache/ : {kv_manifest['total_mb']} MB ({len(kv_manifest['tensors'])} tensors)") +print(f"\nTo copy out: kubectl cp /:{OUT} ./RL_capture") diff --git a/docs/RL/scripts/capture_weights_and_kv.py b/docs/RL/scripts/capture_weights_and_kv.py new file mode 100644 index 00000000..38c6d215 --- /dev/null +++ b/docs/RL/scripts/capture_weights_and_kv.py @@ -0,0 +1,242 @@ +#!/usr/bin/env python3 +"""Capture raw model weights and KV cache data for NIXL nvCOMP compression study. + +Outputs: + weights/{model_name}/tensors/*.bin β€” raw weight tensors (as transferred during RL refit) + weights/{model_name}/manifest.json β€” per-tensor metadata + kvcache/{model_name}/*.bin β€” KV cache tensors from a sample forward pass + kvcache/{model_name}/manifest.json β€” per-KV-tensor metadata + +Usage: + python capture_weights_and_kv.py --model Qwen/Qwen2.5-1.5B --output-dir ./nixl_data + python capture_weights_and_kv.py --model Qwen/Qwen2.5-1.5B --output-dir ./nixl_data --kv-only + python capture_weights_and_kv.py --model Qwen/Qwen2.5-1.5B --output-dir ./nixl_data --weights-only +""" + +import argparse +import json +import os +import time +from pathlib import Path + +import torch +from transformers import AutoModelForCausalLM, AutoTokenizer + + +def sanitize_name(name: str) -> str: + return name.replace("/", "_").replace(".", "_") + + +def tensor_stats(t: torch.Tensor) -> dict: + with torch.no_grad(): + ft = t.float() + return { + "min": float(ft.min()), + "max": float(ft.max()), + "mean": float(ft.mean()), + "std": float(ft.std()), + "abs_mean": float(ft.abs().mean()), + "zero_fraction": float((t == 0).float().mean()), + } + + +def classify_tensor(name: str) -> str: + if "embed" in name: + return "embedding" + if "lm_head" in name: + return "lm_head" + if "layernorm" in name or "norm" in name: + return "norm" + if "q_proj" in name: + return "attention_q" + if "k_proj" in name: + return "attention_k" + if "v_proj" in name: + return "attention_v" + if "o_proj" in name: + return "attention_o" + if "gate_proj" in name or "w1" in name: + return "mlp_gate" + if "up_proj" in name or "w3" in name: + return "mlp_up" + if "down_proj" in name or "w2" in name: + return "mlp_down" + return "other" + + +def get_layer_index(name: str) -> int: + parts = name.split(".") + for i, part in enumerate(parts): + if part == "layers" and i + 1 < len(parts) and parts[i + 1].isdigit(): + return int(parts[i + 1]) + return -1 + + +def capture_weights(model, model_name: str, output_dir: Path, dtype_name: str): + """Dump all model weight tensors as raw binary files + manifest.""" + safe_name = sanitize_name(model_name) + weight_dir = output_dir / "weights" / safe_name / "tensors" + weight_dir.mkdir(parents=True, exist_ok=True) + + manifest = { + "model_name": model_name, + "dtype": dtype_name, + "capture_type": "model_weights_for_rl_refit", + "description": ( + "These are the exact weight tensors transferred from training GPUs to " + "inference GPUs during the RL refit (weight sync) phase. In RL post-training, " + "after each optimizer step, the full model state dict is gathered and sent to " + "the inference engine (vLLM). These tensors represent that payload." + ), + "tensors": [], + } + + total_bytes = 0 + for name, param in model.named_parameters(): + t = param.data.contiguous().cpu() + raw_bytes = bytes(t.untyped_storage())[:t.numel() * t.element_size()] + fname = sanitize_name(name) + ".bin" + (weight_dir / fname).write_bytes(raw_bytes) + + info = { + "name": name, + "file": f"tensors/{fname}", + "shape": list(t.shape), + "dtype": str(t.dtype), + "size_bytes": len(raw_bytes), + "numel": t.numel(), + "layer_index": get_layer_index(name), + "tensor_type": classify_tensor(name), + "stats": tensor_stats(t), + } + manifest["tensors"].append(info) + total_bytes += len(raw_bytes) + + manifest["total_tensors"] = len(manifest["tensors"]) + manifest["total_bytes"] = total_bytes + manifest["total_gb"] = round(total_bytes / 1e9, 3) + + config = model.config + manifest["model_config"] = { + "num_hidden_layers": getattr(config, "num_hidden_layers", None), + "hidden_size": getattr(config, "hidden_size", None), + "intermediate_size": getattr(config, "intermediate_size", None), + "num_attention_heads": getattr(config, "num_attention_heads", None), + "num_key_value_heads": getattr(config, "num_key_value_heads", None), + "vocab_size": getattr(config, "vocab_size", None), + "max_position_embeddings": getattr(config, "max_position_embeddings", None), + "architecture": config.architectures[0] if hasattr(config, "architectures") and config.architectures else None, + } + + manifest_path = output_dir / "weights" / safe_name / "manifest.json" + manifest_path.write_text(json.dumps(manifest, indent=2)) + print(f"Weights captured: {len(manifest['tensors'])} tensors, {manifest['total_gb']} GB β†’ {weight_dir}") + return manifest + + +def capture_kvcache(model, tokenizer, model_name: str, output_dir: Path, seq_len: int = 512): + """Run a forward pass and capture the KV cache tensors.""" + safe_name = sanitize_name(model_name) + kv_dir = output_dir / "kvcache" / safe_name + kv_dir.mkdir(parents=True, exist_ok=True) + + prompt = "The quick brown fox jumps over the lazy dog. " * (seq_len // 10) + inputs = tokenizer(prompt, return_tensors="pt", max_length=seq_len, truncation=True) + + device = next(model.parameters()).device + inputs = {k: v.to(device) for k, v in inputs.items()} + + with torch.no_grad(): + outputs = model(**inputs, use_cache=True) + + past_kv = outputs.past_key_values + + manifest = { + "model_name": model_name, + "capture_type": "kv_cache_prefill_to_decode", + "description": ( + "These are the KV cache tensors produced during the prefill phase and " + "sent to decode workers in disaggregated inference (prefill/decode split). " + "In NIXL-based KV transfer, these are the exact tensors transferred " + "GPU-to-GPU between prefill and decode nodes." + ), + "sequence_length": int(inputs["input_ids"].shape[1]), + "batch_size": 1, + "tensors": [], + } + + total_bytes = 0 + for layer_idx, layer_kv in enumerate(past_kv): + for kv_idx, kv_name in enumerate(["key", "value"]): + t = layer_kv[kv_idx].contiguous().cpu() + raw_bytes = bytes(t.untyped_storage())[:t.numel() * t.element_size()] + fname = f"layer_{layer_idx}_{kv_name}.bin" + (kv_dir / fname).write_bytes(raw_bytes) + + info = { + "name": f"layer_{layer_idx}.{kv_name}", + "file": fname, + "shape": list(t.shape), + "dtype": str(t.dtype), + "size_bytes": len(raw_bytes), + "layer_index": layer_idx, + "kv_type": kv_name, + "stats": tensor_stats(t), + } + manifest["tensors"].append(info) + total_bytes += len(raw_bytes) + + manifest["total_tensors"] = len(manifest["tensors"]) + manifest["total_bytes"] = total_bytes + manifest["total_mb"] = round(total_bytes / 1e6, 3) + + config = model.config + manifest["kv_config"] = { + "num_layers": len(past_kv), + "num_kv_heads": getattr(config, "num_key_value_heads", getattr(config, "num_attention_heads", None)), + "head_dim": getattr(config, "hidden_size", 0) // getattr(config, "num_attention_heads", 1), + "kv_dtype": str(past_kv[0][0].dtype), + } + + manifest_path = kv_dir / "manifest.json" + manifest_path.write_text(json.dumps(manifest, indent=2)) + print(f"KV cache captured: {len(manifest['tensors'])} tensors, {manifest['total_mb']} MB β†’ {kv_dir}") + return manifest + + +def main(): + parser = argparse.ArgumentParser(description="Capture weight and KV cache data for NIXL compression study") + parser.add_argument("--model", type=str, default="Qwen/Qwen2.5-1.5B", help="HuggingFace model name") + parser.add_argument("--output-dir", type=str, default="./nixl_compression_data", help="Output directory") + parser.add_argument("--dtype", type=str, default="bfloat16", choices=["bfloat16", "float16", "float32"]) + parser.add_argument("--device", type=str, default="cpu", help="Device to load model on (cpu or cuda:0)") + parser.add_argument("--weights-only", action="store_true", help="Only capture weights, skip KV cache") + parser.add_argument("--kv-only", action="store_true", help="Only capture KV cache, skip weights") + parser.add_argument("--kv-seq-len", type=int, default=512, help="Sequence length for KV cache capture") + args = parser.parse_args() + + output_dir = Path(args.output_dir) + dtype_map = {"bfloat16": torch.bfloat16, "float16": torch.float16, "float32": torch.float32} + dtype = dtype_map[args.dtype] + + print(f"Loading model: {args.model} (dtype={args.dtype}, device={args.device})") + tokenizer = AutoTokenizer.from_pretrained(args.model, trust_remote_code=True) + model = AutoModelForCausalLM.from_pretrained( + args.model, torch_dtype=dtype, trust_remote_code=True, device_map=args.device, + ) + model.eval() + + if not args.kv_only: + print("\n=== Capturing model weights (RL refit payload) ===") + capture_weights(model, args.model, output_dir, args.dtype) + + if not args.weights_only: + print(f"\n=== Capturing KV cache (prefillβ†’decode, seq_len={args.kv_seq_len}) ===") + capture_kvcache(model, tokenizer, args.model, output_dir, seq_len=args.kv_seq_len) + + print(f"\nDone. Data written to {output_dir}/") + print("Send the output directory to the NIXL compression team.") + + +if __name__ == "__main__": + main() From e8aefab698c741b4099cba6cb0796c3af886f84d Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Wed, 29 Apr 2026 12:11:25 -0700 Subject: [PATCH 24/25] docs(RL): genericize NIXL compression study + add component diagram MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Two changes to NIXL_COMPRESSION_STUDY.md: 1. Add a component-view mermaid diagram at the top showing where the compression-target tensors actually live (RL refit edge between trainer and inference NIXL agents; KV cache edge between prefill and decode), with green nodes / edges marking the compression surface and purple marking RL-stack infrastructure that wouldn't change if nvCOMP slots into the NIXL layer transparently. 2. Drop GKE/cluster-specific assumptions. Previously Option 2 named a specific GKE node pool, namespace, registry, and tsh auth flow as prerequisites; now it just says "a GB200 cluster (ARM64) with at least 2 nodes, container runtime, RDMA-capable interconnect". The K8s manifests are flagged as examples that need light edits (ns, node selectors, registry, RDMA network annotations) per cluster. Hardcoded "kavin" namespace replaced with $NS= throughout the kubectl commands so a copy-paste of the recipe works on any cluster. The capture flow itself was already cluster-agnostic β€” these edits just stop the doc reading like it's only reproducible on our exact GKE shape. Signed-off-by: Kavin Krishnan Made-with: Cursor --- docs/RL/NIXL_COMPRESSION_STUDY.md | 105 +++++++++++++++++++++++++----- 1 file changed, 89 insertions(+), 16 deletions(-) diff --git a/docs/RL/NIXL_COMPRESSION_STUDY.md b/docs/RL/NIXL_COMPRESSION_STUDY.md index 21efa2b2..25f54d6d 100644 --- a/docs/RL/NIXL_COMPRESSION_STUDY.md +++ b/docs/RL/NIXL_COMPRESSION_STUDY.md @@ -24,6 +24,74 @@ Both produce the **exact same kind of data** the NIXL team requested: raw BF16 w --- +## Component View + +Where the data being studied actually lives + what writes/reads it. Green is what the compression team would compress; purple is the existing RL stack producing it. + +```mermaid +flowchart TB + subgraph trainer_node["Trainer node β€” GB200"] + direction TB + T_FSDP["FSDP2 trainer
optimizer.step()"] + T_PUB["MxTrainingPublisher
+ NIXLWeightBroadcast"] + T_NIXL(["NIXL agent (UCX rc_mlx5)"]) + T_FSDP --> T_PUB --> T_NIXL + end + + subgraph mx_meta["Metadata plane β€” MX Server"] + MX["MX Server
(gRPC)"] + REDIS[("Redis")] + MX --> REDIS + end + + subgraph inference_node["Inference node β€” GB200"] + direction TB + I_NIXL(["NIXL agent (UCX rc_mlx5)"]) + I_RECV["MxRefitReceiver
+ NIXLWeightUpdateWorker"] + VLLM["vLLM engine
(live params)"] + I_NIXL --> I_RECV --> VLLM + end + + subgraph prefill_node["Prefill worker β€” GB200 (disagg inference)"] + direction TB + P_FWD["vLLM prefill pass"] + P_KV["KV cache buffer
(per-layer key/value)"] + P_FWD --> P_KV + end + + subgraph decode_node["Decode worker β€” GB200"] + direction TB + D_KV["KV cache import"] + D_GEN["Token generation"] + D_KV --> D_GEN + end + + T_PUB -. "publish metadata
(SourceIdentity, agent blob)" .-> MX + MX -. "discover" .-> I_RECV + + T_NIXL ==> |"β‘  RL REFIT
weights (BF16)
~3 GB / step (1.5B model)
~140 GB / step (70B)"| I_NIXL + P_KV ==> |"β‘‘ KV CACHE
tensors (BF16)
~14 MB at seq=512
~3.5 GB at seq=131K"| D_KV + + style T_FSDP fill:#533483,stroke:#e94560,color:#fff + style T_PUB fill:#533483,stroke:#e94560,color:#fff + style T_NIXL fill:#1b5e20,stroke:#4caf50,color:#fff + style I_NIXL fill:#1b5e20,stroke:#4caf50,color:#fff + style I_RECV fill:#533483,stroke:#e94560,color:#fff + style VLLM fill:#533483,stroke:#e94560,color:#fff + style P_FWD fill:#533483,stroke:#e94560,color:#fff + style P_KV fill:#1b5e20,stroke:#4caf50,color:#fff + style D_KV fill:#1b5e20,stroke:#4caf50,color:#fff + style D_GEN fill:#533483,stroke:#e94560,color:#fff + style MX fill:#533483,stroke:#e94560,color:#fff + style REDIS fill:#162447,stroke:#533483,color:#e0e0e0 +``` + +**Compression target = the green edges.** β‘  is the RL-refit path between trainer and inference NIXL agents; β‘‘ is the KV cache transfer between prefill and decode. Everything purple β€” the trainer, the MX Server, vLLM, the receiver β€” is RL-stack infrastructure that wouldn't change if nvCOMP is added at the NIXL layer (compression would slot in transparently between `register` and `RDMA WRITE` on either edge). + +The capture scripts in [`scripts/`](./scripts/) snapshot the bytes that cross those green edges, plus pre/post weight tensors for delta-compression analysis. + +--- + ## Option 1: Request the pre-captured data package (fastest) We have a ready-made data package captured from a live PRIME-RL deployment on GB200. **It's not in this repo** (binary tensors at GB scale aren't appropriate to commit) β€” request access from `kavink@nvidia.com` and we'll share via the appropriate channel (NV S3 bucket, internal share, or direct upload to your `eschmidt@nvidia.com` inbox per the original request). @@ -82,13 +150,14 @@ Run our validated PRIME-RL overlay workflow and capture weights mid-flight using ### Prerequisites -- GKE cluster with GB200 nodes (ARM64, `customer-gpu-o7v` pool or equivalent) -- `kavin` namespace (or your own) with: - - MX Server running: `modelexpress-server..svc.cluster.local:8001` +- A GB200 cluster (ARM64) with at least 2 nodes, container runtime, and an RDMA-capable interconnect (InfiniBand or RoCE) between nodes. Cluster orchestration is Kubernetes-based; manifests assume `kubectl` access and a working namespace. +- A namespace where you'll deploy the overlay, with the following bound: + - MX Server reachable at `modelexpress-server..svc.cluster.local:8001` (Helm chart in this repo, or use an existing deployment) - Redis backing the MX Server - - `shared-model-cache` PVC for HF model cache - - `nvcr-imagepullsecret` for pulling the overlay image -- `tsh` auth for `nvcr.io/nvidian/dynamo-dev/` + - A shared model-cache PVC for HuggingFace downloads + - An image pull secret for the registry hosting the overlay image (we publish to `nvcr.io/nvidian/dynamo-dev/`; you can also build locally) + +The included K8s manifests under `prime-rl/k8s/prime-rl-mx-on-nixl/` may need light edits (namespace, node selectors, image pull secret name, RDMA network annotations) for your cluster β€” they're examples, not portable across all GB200 deployments. The capture flow itself is cluster-agnostic. ### Step 1: Deploy the PRIME-RL overlay @@ -97,24 +166,26 @@ git clone git@github.com:KavinKrishnan/prime-rl.git cd prime-rl git checkout kavink/mx-on-nixl -# Build the ARM64 image (or use the pre-built one) +# Use the pre-built ARM64 image, or build locally # Pre-built: nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.2 docker buildx build --platform linux/arm64 \ -f docker/Dockerfile.mx-on-nixl \ - -t nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.2 \ + -t /prime-rl-mx-on-nixl:v0.2 \ --push . -# Deploy scenario A (baseline β€” PI's NIXL transport, no MX env vars) +# Edit k8s/prime-rl-mx-on-nixl/*.yaml for your namespace, node selectors, +# RDMA network annotations, and image registry. Then: cd k8s/prime-rl-mx-on-nixl -./run.sh deploy A -./run.sh status # wait until all 3 pods are Running +./run.sh deploy A # scenario A = PI's NIXL transport, no MX env vars +./run.sh status # wait until all 3 pods are Running ``` ### Step 2: Verify the RL loop is running ```bash -kubectl -n kavin logs prime-rl-mx-on-nixl-trainer-0 --tail=20 | grep "SUCCESS.*Step" -kubectl -n kavin logs prime-rl-mx-on-nixl-inference-0 | grep "update_weights.*200" +NS= +kubectl -n $NS logs prime-rl-mx-on-nixl-trainer-0 --tail=20 | grep "SUCCESS.*Step" +kubectl -n $NS logs prime-rl-mx-on-nixl-inference-0 | grep "update_weights.*200" ``` ### Step 3: Capture using the published script @@ -122,19 +193,21 @@ kubectl -n kavin logs prime-rl-mx-on-nixl-inference-0 | grep "update_weights.*20 We ship `capture_on_pod.py` in [`scripts/`](./scripts/) β€” same script that produced our pre-captured Qwen2.5-1.5B package. It captures pre/post RL weights, simulates one AdamW step, computes deltas, and dumps a KV cache prefill, all in one pass. ```bash +NS= + # Copy the script into the trainer pod kubectl cp docs/RL/scripts/capture_on_pod.py \ - kavin/prime-rl-mx-on-nixl-trainer-0:/tmp/capture.py + $NS/prime-rl-mx-on-nixl-trainer-0:/tmp/capture.py # Run it inside the pod (overlay image's interpreter is /app/.venv/bin/python) -kubectl exec kavin/prime-rl-mx-on-nixl-trainer-0 -- /app/.venv/bin/python /tmp/capture.py \ +kubectl exec $NS/prime-rl-mx-on-nixl-trainer-0 -- /app/.venv/bin/python /tmp/capture.py \ --model Qwen/Qwen2.5-1.5B \ --out /tmp/nixl_capture \ --kv-seq-len 512 \ --lr 5e-6 # Copy the results back -kubectl cp kavin/prime-rl-mx-on-nixl-trainer-0:/tmp/nixl_capture ./RL_capture +kubectl cp $NS/prime-rl-mx-on-nixl-trainer-0:/tmp/nixl_capture ./RL_capture ``` Output `RL_capture/` contains four sub-directories (`weights_pre_rl/`, `weights_post_rl/`, `weight_deltas/`, `kv_cache/`) each with raw `.bin` files plus a `manifest.json`. See [`scripts/README.md`](./scripts/README.md) for the full layout + flag reference. From 16ce4feddecb66a2b5b3eb1ef6bf51cc48203d1e Mon Sep 17 00:00:00 2001 From: Kavin Krishnan Date: Wed, 29 Apr 2026 14:45:12 -0700 Subject: [PATCH 25/25] docs(RL): drop named NIXL inbox from compression study doc MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Removes both references to eschmidt@nvidia.com from NIXL_COMPRESSION_STUDY.md so the guide reads as a general team-facing doc rather than addressed at one inbox. Audience line now just says "NIXL compression team"; Option 1 channel list trims "direct upload to your eschmidt@nvidia.com inbox per the original request" down to "direct upload" β€” same channel options, no person-specific routing. Single contact for the data package remains kavink@nvidia.com. Signed-off-by: Kavin Krishnan Made-with: Cursor --- docs/RL/NIXL_COMPRESSION_STUDY.md | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/docs/RL/NIXL_COMPRESSION_STUDY.md b/docs/RL/NIXL_COMPRESSION_STUDY.md index 25f54d6d..e5df9cd4 100644 --- a/docs/RL/NIXL_COMPRESSION_STUDY.md +++ b/docs/RL/NIXL_COMPRESSION_STUDY.md @@ -1,7 +1,7 @@ # NIXL nvCOMP Compression Study β€” Reproducing with ModelExpress RL Workflows **Last Updated**: April 29, 2026 -**Audience**: NIXL compression team (`eschmidt@nvidia.com`) +**Audience**: NIXL compression team **Purpose**: Guide the NIXL team to capture and study real RL weight-transfer payloads using our validated PRIME-RL and verl workflows with ModelExpress (MX). --- @@ -94,7 +94,7 @@ The capture scripts in [`scripts/`](./scripts/) snapshot the bytes that cross th ## Option 1: Request the pre-captured data package (fastest) -We have a ready-made data package captured from a live PRIME-RL deployment on GB200. **It's not in this repo** (binary tensors at GB scale aren't appropriate to commit) β€” request access from `kavink@nvidia.com` and we'll share via the appropriate channel (NV S3 bucket, internal share, or direct upload to your `eschmidt@nvidia.com` inbox per the original request). +We have a ready-made data package captured from a live PRIME-RL deployment on GB200. **It's not in this repo** (binary tensors at GB scale aren't appropriate to commit) β€” request access from `kavink@nvidia.com` and we'll share via the appropriate channel (NV S3 bucket, internal share, or direct upload). Package contents: