feat(RL/post-2389): Phase 3+4 v2 client — compile-target registry + multi-source slice planner#349
Draft
KavinKrishnan wants to merge 40 commits into
Presentation covering ModelExpress integration into RL post-training weight sync (refit) for NeMo RL, verl, and PRIME-RL. Includes HTML slideshow and standalone SVG diagrams for architecture, transfer flow, component stack, and framework comparison. Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Mirrors all 6 slides from the HTML presentation in plain markdown for easier viewing on GitHub and compatibility with Marp/Slidev. Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Training-side publisher registers updated model weights with NIXL and publishes metadata to the MX Server. Inference-side receiver discovers sources via ListSources, pulls weights via RDMA, and yields (name, tensor) pairs compatible with vLLM's load_weights(). Supports both all-at-once and layer-by-layer streaming patterns for PRIME-RL integration. Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
…hash mismatch) Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Tensor memory addresses don't change between optimizer steps, only values do. Calling register_memory every step accumulated descriptors, inflating the metadata blob from ~27 KB to 800+ KB and causing NIXL_ERR_NOT_ALLOWED on add_remote_agent. Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
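A minimal sketch of the register-once idea described above, assuming a registration callback that takes an (address, size) pair. `RegistrationCache` and `register_fn` are hypothetical names for illustration; they are not the NIXL or ModelExpress API.

```python
from dataclasses import dataclass, field
from typing import Callable, Dict, Tuple

Region = Tuple[int, int]  # (base address, size in bytes)


@dataclass
class RegistrationCache:
    """Registers each memory region once, however many steps publish it."""

    # Hypothetical stand-in for the real registration call.
    register_fn: Callable[[Region], object]
    _handles: Dict[Region, object] = field(default_factory=dict)

    def ensure_registered(self, addr: int, nbytes: int) -> object:
        key = (addr, nbytes)
        if key not in self._handles:
            # First time this region is seen: register and remember the handle.
            self._handles[key] = self.register_fn(key)
        return self._handles[key]


# Simulate three optimizer steps over the same two buffers.
calls = []
cache = RegistrationCache(register_fn=lambda region: calls.append(region) or len(calls))
for _step in range(3):
    cache.ensure_registered(0x1000, 4096)
    cache.ensure_registered(0x2000, 8192)
```

With this shape the descriptor count stays at one per buffer across steps, rather than growing with every publish.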
Allocates temporary GPU buffers matching the source's tensor layout, receives via NIXL RDMA, and yields (name, tensor) pairs in HF format. The caller's model.load_weights() handles name mapping and tensor fusion (e.g. HF q/k/v -> vLLM qkv_proj). Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
…ible with scratch buffers) Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
…ht reshaping Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
…rams Covers the MxCheckpointEngine design, Ray actor topology, and GB200 prototype results (10 steps, avg ~1.25s cross-node RDMA weight sync). Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
The section referenced recovery/ paths outside the ModelExpress repo that aren't accessible to external readers. Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
…og value
- Rename §2 to frame MX as additive on verl's NIXL checkpoint engine
- Document native nixl ring path positively; optional MX catalog + star
- Add catalog benefits (balancing, multi-source, publish/retire, retention)
- Fix RDMA READ source/destination wording; remove PRIME-RL references
- Tone: prefer native nixl vs consider mx; align metrics cross-refs
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Companion to PRIMERL_MX_OVERVIEW.md (Path A). Documents an MX-shaped
weight broadcast design for prime-rl that uses PI's NIXL transport as
the data plane but exposes ModelExpress's traditional API surface
(model-agnostic, server-mediated, scratch-buffer-default,
cross-framework-portable) instead of PI's per-model conversion_specs +
slot system.
Key positioning:
- Path A = strict overlay on PI's API. Smallest diff. In flight.
- Path B = native MX shape. Larger diff. Staged design only.
- Both ship as discriminator options on weight_broadcast.type
(existing nixl + new mx coexist).
Documents:
- Why we'd consider B (per-model spec wall, KL drift inheritance,
elastic shapes, cross-framework alignment).
- File footprint (~600 LOC new, ~70 LOC modified).
- Migration path (toml-only for users).
- Pitch sequence for PI conversation.
- Inflection points to pivot from A to B.
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Fix two diagrams in PRIMERL_MX_OVERVIEW.md that implied MX replaces
PI #2326's entire control plane, when v0.1 is env-var-gated and only
swaps SPG coordinator discovery.
Component diagram:
- Label every imported PI element "(PI, unchanged)" and every MX
element "(MX overlay, env-var gated)" so the reader can see the split.
- Re-draw control-plane edges as dotted green (MX additions):
publish_spg_coordinator (boot), mark_version_ready (per step),
discover_spg_coordinator (boot), publish_rollout_source (optional).
- Add the SPG 2-round all_gather_obj edge between trainer and rollout,
labeled as PI code. It was previously missing entirely, so readers
could think MX alone wired up NIXL agents.
- Data-plane edge labeled "trainer -> rollout recv buffer" to match
PI's actual WRITE semantics instead of a generic bidirectional arrow.
Timing diagram:
- Wrap flow in three tinted bands: green boot-time MX discovery (the
only v0.1 change) vs purple per-step SPG metadata rounds + RDMA WRITE
(PI, unchanged).
- Move discovery out of the per-step par block into a "once per run"
boot-time block: register_coordinator / discover_spg_coordinator are
init-time, not update-time ops.
- Add the SPG 2-round all_gather_obj step that was missing: round 1
exchanges agent_meta + slot_layout + recv-buffer descriptors, round 2
exchanges per-slot xfer descriptors.
- Relabel the RDMA step as "NIXL RDMA WRITE -> rollout's recv buffer"
so the write direction is explicit.
- Trigger for unpublish changed from vague "next iteration" to
"async_level >= 1", which ties the mutability contract to the actual
concurrency regime that needs it.
- Add explicit legend calling out MX-added vs PI-unchanged.
No design changes; docs-only clarification so the diagrams match what
the overlay code actually does.
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Scenario A (PI NIXL direct refit on Qwen3-0.6B, 2 trainer ranks × 1
inference rank) completed all 20 RL training steps end-to-end on GB200
on April 24. Update the doc to reflect reality.
Changes:
- Status line: "Design complete ... Metrics below are TBD" →
"Scenario A green end-to-end" with concrete numbers in the lead
paragraph (596 MB/push, 310 slots, 100% success, draft PR link).
- §2 observed per-step timing table: replace PI-reported 12-node
projections with measured scenario A numbers (20/20 steps, 5.1 s avg,
596 MB bucket, 310 slots published, rank-0 writes all 310 / rank-1
writes 197). B and C cells remain "pending" until the next session
flips the env vars.
- §2 caveat added explaining that throughput comparison vs PI's prod
numbers is distorted by our SDPA fallback (ARM64 image ships a
flash_attn import stub; real kernels are a P1 follow-up). NIXL
transfer itself is unaffected; step-time inflation is on the
training-compute side.
- §4.5 metrics matrix: populate scenario A column with measured
values; mark B/C/D/E as "pending" rather than "TBD" to signal the
scenarios are scoped + instrumented, just not yet run.
- §4.6 results summary: rewrite from "to be written after the
benchmark run" to a concrete list of what Scenario A proved
(foundation validated, per-rank sharding-aware works as PI designed,
overlay is structurally correct). Add pointers to the nine blocker
fixes documented in OVERLAY_PR_EXECUTION_STATE.md. Keep the B/C/D/E
expectations section framed as "next-session targets" so readers know
what to expect the doc to grow into.
No diagram changes; April 24 timing + component diagram fixes from
461be85 remain the current shape.
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Path A overlay scenarios A, B, and C all completed 20/20 training
steps on GB200. Update §2, §4.5, §4.6 with the measured numbers and
update §1's status line accordingly. Diagrams unchanged.
§2 (timing diagram + observed timing table):
- Replace pending B/C cells with measured values.
- Wire BW per trainer NIC: 7.82-8.84 GB/s (avg ~8.1) — exceeds
PI's reported ~7.5 GB/s prod target.
- Aggregate net BW: 35-39 GB/s rank 0 + rank 1 combined.
- Per-push breakdown (scenario C): convert 60-67 ms + post+wait
15-16 ms + barrier 1.2-1.6 ms ≈ 80 ms total for 596 MB.
- Add the pipeline-replication catalog state output (4 sources
incl. rollout-source-0-*) as the empirical proof the MX-side
protocol works.
§4.5 metrics matrix:
- All A/B/C cells now have measured values; D and E flagged as
deferred (with reasons).
- 4.5.2 throughput row gains the wire/net BW measurements.
§4.6 results summary rewritten:
- A: foundation validated, nine blockers resolved.
- B: MX-mediated discovery validated (single source_id shared
across all participants), data-path parity with A. Documents
the metadata_endpoint→nixl_metadata workaround we landed
during this session.
- C: pipeline-replication catalog entry confirmed; per-push wire/
net BW measured. Honestly notes the bandwidth-amplification
benefit isn't shown end-to-end yet (gated on PI-side dynamic
SPG world_size for elastic mid-run join).
- D and E: explicitly deferred with rationale (D needs a drift
reproducer to be valuable; E gates on dynamic SPG).
§1 status line: "scenarios A, B, and C all green on GB200" replaces
the "scenario A green" wording.
Net for the PR-on-PR: A and B are the strongest evidence (overlay
is additive without regressing data path). C's catalog entry plus
the measured per-push BW round out the picture. D and E are honest
follow-up axes.
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Rebased onto current main (was 3 weeks stale; resolved one trivial
__all__ merge conflict in modelexpress/__init__.py).
Python — correctness fixes:
1. refit_receiver.poll_for_source: was hardcoding training_step=0
on the returned SourceRef and never filtering on min_step despite
advertising both in the docstring. ListSourcesResponse instances
carry only SourceInstanceRef (no extra_parameters), so the actual
training_step lives on SourceIdentity in the publisher's metadata.
Now do a per-candidate get_metadata() lookup, parse training_step
from SourceIdentity.extra_parameters, and skip candidates whose
step is below the threshold or unparseable. Cost: extra gRPC
round-trip per candidate; can be removed once training_step is
surfaced on SourceInstanceRef directly.
2. training_publisher.initialize(): training_framework was
hardcoded to "prime_rl" in _build_identity, which mislabeled
verl-published sources. Now a parameter on initialize() (default
"unknown" so callers know to set it explicitly).
3. training_publisher publish_weights / publish_layer mutual
exclusivity: publish_layer registers fresh tensors every call but
publish_weights caches via self._registered, so interleaving the
two paths could leave NIXL holding only the most-recently-
registered tensor set. New self._publish_mode tracks which path
is in use; either method raises if the other was already used on
this publisher.
4. refit_receiver._DTYPE_MAP: lifted to module scope (was rebuilt
per call inside receive_weights_scratch).
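The mutual-exclusivity rule in fix 3 can be sketched as a small guard object. This is an illustration under assumptions: `PublishModeGuard` and `claim` are hypothetical names, and only the `_publish_mode` attribute name comes from the commit message.

```python
class PublishModeGuard:
    """Tracks which publish path a publisher uses and rejects mixing them."""

    _VALID_MODES = ("weights", "layer")

    def __init__(self) -> None:
        self._publish_mode = None  # unset until the first publish call

    def claim(self, mode: str) -> None:
        if mode not in self._VALID_MODES:
            raise ValueError(f"unknown publish mode: {mode!r}")
        if self._publish_mode is None:
            # First publish call decides the mode for this publisher.
            self._publish_mode = mode
        elif self._publish_mode != mode:
            raise RuntimeError(
                f"publisher already used mode {self._publish_mode!r}; "
                f"cannot switch to {mode!r}"
            )


guard = PublishModeGuard()
guard.claim("weights")      # a publish_weights-style call
guard.claim("weights")      # repeating the same path is fine
try:
    guard.claim("layer")    # a publish_layer-style call on the same publisher
    mixed_allowed = True
except RuntimeError:
    mixed_allowed = False
```

Each publish method would call `claim` with its own mode before touching any registration state, so the failure surfaces before the tensor set can be clobbered.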
Docs — content fixes:
5. VERL_MX_OVERVIEW.md deployment-mode table: replaced ❌ / ✅
emoji markers with plain text per repo "no emojis in markdown"
guideline.
6. PRIMERL_MX_OVERVIEW.md §3.9: fixed duplicate "byte-exact
byte-exact" → "byte-exact".
7. MD040: annotated 15 bare ``` fences across MX_RL_OVERVIEW.md,
PRIMERL_MX_OVERVIEW.md, VERL_MX_OVERVIEW.md,
PRIMERL_MX_NATIVE_DESIGN.md, mx-rl-integration-slides.md as
```text where they were carrying plain prose / ASCII layout.
8. ASCII → mermaid:
- MX_RL_OVERVIEW.md §Architecture: ASCII trainer/server/inference
swimlane → sequenceDiagram.
- PRIMERL_MX_OVERVIEW.md §3.2 DAG buildup: 5-phase ASCII timeline
→ flowchart with one subgraph per phase, plus a per-phase
bandwidth table.
- PRIMERL_MX_OVERVIEW.md §3.9 before/after: naive-allgather vs
overlay per-rank flow → side-by-side flowchart.
- Slide-deck ASCII bottleneck-bar / 3-column architecture
intentionally retained: those are CSS-styled visual fallbacks
for the SVGs, not Markdown rendering targets. The misleading
"[ INSERT DIAGRAM: diagram-architecture.svg ]" placeholder text
above the architecture fallback was removed.
No proto / server-side changes. The poll_for_source fix is the
proto-level workaround documented in the CodeRabbit review; the
forward-looking fix (adding training_step directly to
SourceInstanceRef so we don't need the per-candidate get_metadata)
is a follow-up.
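A sketch of the per-candidate filtering that the poll_for_source fix describes, assuming the step is stored under a `training_step` key in the extra parameters. `fetch_extra_parameters` is a hypothetical callback standing in for the get_metadata() round-trip; it is not the client's real signature.

```python
from typing import Callable, Iterable, List, Mapping, Optional, Tuple


def parse_training_step(extra_parameters: Mapping[str, str]) -> Optional[int]:
    """Return the step as an int, or None when it is missing or unparseable."""
    raw = extra_parameters.get("training_step")
    try:
        return int(raw)
    except (TypeError, ValueError):
        return None


def filter_by_min_step(
    candidate_ids: Iterable[str],
    fetch_extra_parameters: Callable[[str], Mapping[str, str]],
    min_step: int,
) -> List[Tuple[str, int]]:
    """Keep candidates whose published step is at least min_step."""
    kept = []
    for candidate_id in candidate_ids:
        # One metadata lookup per candidate; this is the extra round-trip cost.
        step = parse_training_step(fetch_extra_parameters(candidate_id))
        if step is None or step < min_step:
            continue
        kept.append((candidate_id, step))
    return kept


fake_metadata = {
    "a": {"training_step": "3"},
    "b": {"training_step": "not-a-number"},
    "c": {},
    "d": {"training_step": "7"},
}
fresh = filter_by_min_step(fake_metadata, fake_metadata.__getitem__, min_step=5)
```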
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Adds NIXL_COMPRESSION_STUDY.md to help the NIXL nvCOMP compression
team reproduce our RL weight-transfer payloads using our validated
PRIME-RL and verl workflows.
Three paths documented:
1. Pre-captured data (fastest): pointer to our existing Qwen2.5-1.5B
data package (model.safetensors + pre/post RL weights + deltas + KV
cache, captured from live GB200 deployment).
2. End-to-end reproduction on GB200 via the PRIME-RL overlay (PR
PrimeIntellect-ai/prime-rl#2343): deploy scenario A, exec into
trainer pod, capture state_dict + simulate one RL step + dump KV
cache. Step-by-step with kubectl commands.
3. Reproduction via verl MxCheckpointEngine (PR
ai-dynamo/modelexpress#252): same tensor content, different
transport path.
Also covers: compression-relevant properties table, per-tensor layout
for Qwen3-0.6B and Qwen2.5-1.5B, delta analysis notes (BF16 deltas
mostly zero at RL learning rates; FP32 diffs are the meaningful
analysis target), NIXL integration point for nvCOMP (transparent:
compress/decompress at the NIXL layer, no MX or framework changes),
and a model-size scaling table for larger captures.
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
The pre-captured Qwen2.5-1.5B data package referenced in Option 1 of NIXL_COMPRESSION_STUDY.md isn't in this repo (binary tensors at GB scale aren't appropriate to commit) and the path I had previously shown was an internal local checkout. Replace with explicit "request from kavink@nvidia.com" framing and call out the appropriate channels (NV S3, internal share, or direct upload to eschmidt@nvidia.com per the original ask). Add the total package size (~14 GB) so the NIXL team knows what to expect bandwidth-wise. Update the "larger models" cross-reference accordingly. Signed-off-by: Kavin Krishnan <kavink@nvidia.com> Made-with: Cursor
Adds the two scripts that produced the Qwen2.5-1.5B data package we
referenced in NIXL_COMPRESSION_STUDY.md so the NIXL team can reproduce
captures themselves on different models / sequence lengths / clusters
without going through us as a manual relay.
New files:
docs/RL/scripts/capture_weights_and_kv.py
Standalone — any HF model, any host (CPU or single GPU), no
cluster / RL framework needed. CLI flags for model, dtype,
device, output dir, weights/KV-only modes, KV seq_len.
docs/RL/scripts/capture_on_pod.py
Inside-a-running-RL-pod variant. Generalized vs the original
Qwen2.5-1.5B-only capture: --model, --out, --kv-seq-len, --lr
flags. Captures pre/post RL weights + simulated AdamW step
delta + KV cache in one pass. Produces the four-directory
layout (weights_pre_rl/, weights_post_rl/, weight_deltas/,
kv_cache/) we shipped to the compression team.
docs/RL/scripts/README.md
Quick reference for both scripts: when to use each, complete
CLI examples, output layout, the BF16-deltas-are-mostly-zero
note + FP32 analysis snippet, pointer back to the main study
doc.
Updated:
docs/RL/NIXL_COMPRESSION_STUDY.md
Option 2 now points at scripts/capture_on_pod.py with kubectl
cp + exec invocation instead of an inlined heredoc Python
block. Added Option-2-Step-4 ("standalone capture without a
running RL deployment") pointing at capture_weights_and_kv.py
for users who don't want to deploy the full overlay.
The original on-disk capture scripts in our internal recovery
directory are unchanged; this just publishes a generalized,
flag-driven version of each into the public docs tree.
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
Two changes to NIXL_COMPRESSION_STUDY.md: 1. Add a component-view mermaid diagram at the top showing where the compression-target tensors actually live (RL refit edge between trainer and inference NIXL agents; KV cache edge between prefill and decode), with green nodes / edges marking the compression surface and purple marking RL-stack infrastructure that wouldn't change if nvCOMP slots into the NIXL layer transparently. 2. Drop GKE/cluster-specific assumptions. Previously Option 2 named a specific GKE node pool, namespace, registry, and tsh auth flow as prerequisites; now it just says "a GB200 cluster (ARM64) with at least 2 nodes, container runtime, RDMA-capable interconnect". The K8s manifests are flagged as examples that need light edits (ns, node selectors, registry, RDMA network annotations) per cluster. Hardcoded "kavin" namespace replaced with $NS=<your-namespace> throughout the kubectl commands so a copy-paste of the recipe works on any cluster. The capture flow itself was already cluster-agnostic — these edits just stop the doc reading like it's only reproducible on our exact GKE shape. Signed-off-by: Kavin Krishnan <kavink@nvidia.com> Made-with: Cursor
Removes both references to eschmidt@nvidia.com from NIXL_COMPRESSION_STUDY.md so the guide reads as a general team-facing doc rather than addressed at one inbox. Audience line now just says "NIXL compression team"; Option 1 channel list trims "direct upload to your eschmidt@nvidia.com inbox per the original request" down to "direct upload" — same channel options, no person-specific routing. Single contact for the data package remains kavink@nvidia.com. Signed-off-by: Kavin Krishnan <kavink@nvidia.com> Made-with: Cursor
…filter, tree fan-out
Adds the MX-side support for the NemoRL integration design (see
pensieve/RL/NemoRL/04_design_v2_moe_rank_to_rank.md). Built on top of the
existing MxTrainingPublisher / MxRefitReceiver as a Python-only shim, so
this lands without Rust server changes.
What's new:
* shape_descriptors: TensorDescriptorV2 + DTensor placement -> wire format
helpers. Handles Replicate / Shard / Partial; computes per-rank local
shard ranges; supports MoE expert axis with owned_expert_ids.
* nemo_rl_v2: MxV2TrainingPublisher / MxV2RefitReceiver / TrainerWorldLayout
/ V2SourceCandidate. Implements:
- rank-to-rank publish (each rank publishes its own local shard, no
allgather)
- same-rank-only routing (PrimeRL GB200 lesson: avoids cross-NIC subnet
writes that were causing NIXL_ERR_REMOTE_DISCONNECT)
- freshest-per-rank dedup by updated_at (avoids stale READY peers from
orchestrator restart cascades, which were causing NIXL_ERR_NOT_ALLOWED)
- tree fan-out via publish_self_as_source (TensorHub paper 2604.09107v1
pipeline replication)
- MoE expert coverage filter for receivers in EP layouts
- auto heartbeat via existing HeartbeatThread
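The same-rank routing and freshest-per-rank dedup rules can be sketched as plain selection logic. The `Candidate` dataclass below is a hypothetical stand-in with only the fields the rules need; it is not the real V2SourceCandidate.

```python
from dataclasses import dataclass
from typing import Dict, Iterable, Optional


@dataclass(frozen=True)
class Candidate:
    source_id: str
    worker_rank: int
    updated_at: float  # larger means more recently updated


def freshest_per_rank(candidates: Iterable[Candidate]) -> Dict[int, Candidate]:
    """Keep one candidate per rank: the one with the newest updated_at."""
    best: Dict[int, Candidate] = {}
    for cand in candidates:
        current = best.get(cand.worker_rank)
        if current is None or cand.updated_at > current.updated_at:
            best[cand.worker_rank] = cand
    return best


def pick_for_receiver(candidates: Iterable[Candidate], my_rank: int) -> Optional[Candidate]:
    """Same-rank-only routing: a receiver only pulls from its own rank."""
    return freshest_per_rank(candidates).get(my_rank)


catalog = [
    Candidate("stale-r0", worker_rank=0, updated_at=100.0),
    Candidate("fresh-r0", worker_rank=0, updated_at=250.0),
    Candidate("only-r1", worker_rank=1, updated_at=240.0),
]
chosen = pick_for_receiver(catalog, my_rank=0)
```

A stale entry left behind by a restart loses to the newer entry for the same rank, and a receiver with no same-rank source gets `None` rather than a cross-rank fallback.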
v2 metadata transport (3-tier fallback):
1. SourceIdentity.extra_parameters via meta.identity (cleanest; requires
Rust server to populate the new identity field on
GetMetadataResponse)
2. Synthetic TensorDescriptor sidecar __mx_v2_meta__ with JSON in dtype
field (the path that works against the current server, which drops
SourceIdentity and most string fields when echoing WorkerMetadata)
3. WorkerMetadata.agent_name string-encoded marker (legacy fallback)
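The three tiers above can be read as an ordered lookup. The sketch below works on a plain dict rather than the real protobuf messages, and the `mxv2:` agent-name prefix is an invented encoding used only for illustration; the actual marker format is not described here.

```python
import json
from typing import Optional, Tuple

SIDECAR_NAME = "__mx_v2_meta__"   # sidecar tensor name from the commit message
AGENT_NAME_PREFIX = "mxv2:"       # hypothetical marker encoding


def resolve_v2_metadata(meta: dict) -> Optional[Tuple[str, dict]]:
    """Return (tier, payload) from the first tier that carries v2 metadata."""
    # Tier 1: identity.extra_parameters, when the server echoes identity.
    identity = meta.get("identity") or {}
    extra = identity.get("extra_parameters")
    if extra:
        return "identity", dict(extra)

    # Tier 2: synthetic tensor descriptor with JSON in its dtype field.
    for tensor in meta.get("tensors") or []:
        if tensor.get("name") == SIDECAR_NAME:
            return "sidecar", json.loads(tensor["dtype"])

    # Tier 3: string-encoded marker in agent_name.
    agent_name = meta.get("agent_name") or ""
    if agent_name.startswith(AGENT_NAME_PREFIX):
        return "agent_name", json.loads(agent_name[len(AGENT_NAME_PREFIX):])

    return None


payload = {"mx_v2": "1"}
via_identity = resolve_v2_metadata({"identity": {"extra_parameters": payload}})
via_sidecar = resolve_v2_metadata({
    "identity": None,
    "tensors": [
        {"name": "w", "dtype": "bfloat16"},
        {"name": SIDECAR_NAME, "dtype": json.dumps(payload)},
    ],
})
via_agent = resolve_v2_metadata({"agent_name": AGENT_NAME_PREFIX + json.dumps(payload)})
```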
Proto change: added SourceIdentity identity = 5 to GetMetadataResponse.
Backward-compatible (older clients ignore). Python stubs regenerated.
Tests:
* 15 unit tests covering shape descriptor round-trip (Replicate, Shard,
MoE expert), expert codec, world-layout codec, picker filtering
(same-rank, min-version, mx_v2 marker, expert coverage, trainer
fallback), DAG fan-out, and agent_name fallback for legacy servers.
* scripts/v2_moe_e2e_demo.py: standalone GB200 cluster demo. Validated
on dynamo-gcp-dev-02 in kavin namespace: 4 ranks x 2 cycles, real
NIXL RDMA, same-rank routing, freshness dedup, MoE expert sharding,
sidecar transport. All 8 transfers correct.
Companion changes in NVIDIA/RL on branch kavink/mx_integration adopt
this client.
…lients
Adds the SourceIdentity round-trip that v2 RL clients (NemoRL
update_weights_via_mx, prime-rl PR #2389 follow-ups) need to read
framework-level state from extra_parameters.
Changes:
* metadata_backend.rs: ModelMetadataRecord gains an
Option<SourceIdentity>. Old records (pre-v2 storage) leave it None.
* p2p_service.rs::get_metadata: populates GetMetadataResponse.identity
from the record's identity field (the new proto field added in the
preceding commit on the client branch).
* metadata_backend/redis.rs: SourceAttributesJson gains an
extra_parameters HashMap (with #[serde(default)] for back-compat) and
a to_source_identity() method that reconstructs the full
SourceIdentity from the stored attributes hash. Old records without
extra_parameters deserialize cleanly.
* metadata_backend/kubernetes.rs: identity stays None for now; CRD
schema bump is a separate change. v2 clients fall back to the sidecar
transport (synthetic TensorDescriptor) until then.
* state.rs + p2p_service.rs test fixtures: identity: None added to
match the new struct shape.
This change is forward-compatible: pre-existing records read into
ModelMetadataRecord with identity=None and old clients ignore the new
GetMetadataResponse.identity field.
Pairs with the proto change in commit 97c0e78 that added
SourceIdentity identity = 5 to GetMetadataResponse, and with the v2
NemoRL client which already prefers identity.extra_parameters when
available and falls back to the synthetic-tensor-descriptor sidecar
otherwise.
Build/deploy: requires rebuilding modelexpress-server image and
redeploying. The current image at
nvcr.io/nvidian/dynamo-dev/modelexpress-server:latest predates these
fields; the v2 RL prototype works against it via the sidecar
transport.
Comprehensive companion doc for the NemoRL upstream PR. Covers:
* motivation (PrimeRL GB200 lessons, TensorHub paper, Composer 2)
* the four design pillars (rank-to-rank publish, tree scale-out, MoE
expert filtering, explicit shape registry)
* full Python API surface with worked code snippets
* file inventory across kavink/nemo_rl_moe (MX) and
kavink/mx_integration (NemoRL)
* the three-tier metadata transport workaround for the running server
(SourceIdentity → __mx_v2_meta__ sidecar → agent_name fallback)
* what was tested:
- 15/15 unit tests passing (no GPU/NIXL)
- live cluster gRPC smoke
- live E2E on GB200, toy scale (4×2 cycles, byte-correct)
- live E2E on GB200, production scale (4×1.6 GB Qwen3-30B-A3B-shaped
in 11–16 ms per transfer)
- arm64 Docker overlay built + smoke-tested
- explicit list of what was NOT exercised
* server-side patch path (rebuild image to land identity round-trip)
* deployment recipe (config.yaml + expected log lines)
* roadmap (async refit, Megatron, SGLang, dirty-experts, cross-DC,
drain semantics)
Self-contained for upstream review; replaces no existing doc.
The doc is intended to be self-contained for upstream review; the internal-pensieve cross-references would dangle for outside readers. Replaced with public references (PR links + sister docs already in docs/RL/).
…ensors
Two bugs caught by a real-DTensor (not faked) e2e test on GB200, where
the publisher wraps a torch.distributed.tensor.DTensor instead of a
stand-in object:
1. shape_descriptors.describe_tensor: the SHARD branch assumed
tensor.shape was the LOCAL view, which is true for plain tensors but
FALSE for real DTensors: DTensor.shape is the global, un-sharded
shape. With FSDP=4 and HIDDEN=1024, this caused global_shape to be
computed as 1024*4=4096 (wrong) and local_shard_range as rank*1024
instead of rank*256. Fix: detect torch.distributed.tensor.DTensor
and use tensor.to_local().shape[dim] as local_extent. Plain-tensor
and stand-in-DTensor paths are unchanged.
2. nemo_rl_v2.MxV2TrainingPublisher: shape_registry was only being
transmitted via SourceIdentity.extra_parameters, which the running
Rust server drops. Sidecar transport (the synthetic __mx_v2_meta__
TensorDescriptor) carries the v2 marker but not the registry.
Receivers had no per-tensor placement info. Fix: include
shape_registry in the sidecar JSON (nested-string ok, JSON encoder
handles it). Receiver-side parsing already handles
extra["shape_registry"].
Validated end-to-end on prime-rl-nixl-mx-trainer-0 (GB200, 4 GPUs):
4 ranks × 2 refit cycles, real torch.distributed mesh + Shard(0)
placement. Registry now reports correct global=(1024,2048) +
local_range=(rank*256, (rank+1)*256) per rank. All 8 transfers
all_elem_match=True (every byte of every received local shard equals
its rank+version sentinel).
Companion: scripts/v2_dtensor_e2e_demo.py exercises this codepath.
15/15 unit tests still pass (their stand-in fake-DTensor hits the
plain-tensor branch of describe_tensor).
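The shard arithmetic behind bug 1 can be checked without torch. This is a sketch that assumes even sharding along one axis; `local_shard_range` here is a hypothetical helper, not the function in shape_descriptors.

```python
from typing import Tuple


def local_shard_range(global_extent: int, world_size: int, rank: int) -> Tuple[int, int]:
    """Half-open [start, stop) range owned by `rank` on an evenly sharded axis."""
    if global_extent % world_size != 0:
        raise ValueError("this sketch only handles evenly divisible extents")
    if not 0 <= rank < world_size:
        raise ValueError("rank out of range")
    chunk = global_extent // world_size
    return rank * chunk, (rank + 1) * chunk


HIDDEN, FSDP = 1024, 4

# Correct: the sharded tensor's .shape is already global, so divide to get the chunk.
correct = [local_shard_range(HIDDEN, FSDP, r) for r in range(FSDP)]

# The bug: treating the global extent as local and multiplying it back up.
buggy_global = HIDDEN * FSDP                                      # the wrong global_shape
buggy_ranges = [(r * HIDDEN, (r + 1) * HIDDEN) for r in range(FSDP)]
```

The correct ranges tile the 1024-wide axis exactly, while the buggy ones describe a 4096-wide tensor that does not exist.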
Companion to v2_moe_e2e_demo.py that exercises the codepath that DTensorPolicyWorker.stream_weights_via_mx uses in production: real torch.distributed.tensor.DTensor on a (WORLD_SIZE,) mesh + Shard(0) placement, MxV2TrainingPublisher.add_tensor reading .placements off the DTensor, NIXL register on the local view, sidecar transport round-trip, byte-level correctness on the same-rank pull. Asserts: * registry's global_shape == DTensor.shape (un-sharded) * registry's local_shard_range == (rank*chunk, (rank+1)*chunk) * received local shard byte-matches every-cell sentinel
Documents the just-validated codepath: 4 ranks × 2 cycles, real torch.distributed.tensor.DTensor on a (WORLD_SIZE,) mesh + Shard(0) placement. All 8 transfers byte-correct (all_elem_match=True), and the receiver's reconstructed shape registry reports correct global=(1024,2048) + per-rank local_range=(rank*256, (rank+1)*256). Also points out the two real bugs the test caught (DTensor.shape semantics, sidecar shape_registry) and renumbers the existing §8.5+ sections accordingly. The §8.7 'NOT validated' entry on DTensorPolicyWorker.stream_weights_via_mx is updated — the v2 protocol mechanics are now exercised; only NemoRL-specific outer glue (HF state_dict walk, MoE expert-name heuristic, cpu_offload lifecycle) remains as integration-level work.
#295) Signed-off-by: John Thomson <jothomson@nvidia.com>
… in loopback
Both v2 demo scripts (v2_moe_e2e_demo and v2_dtensor_e2e_demo) ran on
GB200 via UCX's intra-node cuda_ipc fast path, which silently
tolerates malformed prep_xfer_dlist entries. That's why neither demo
caught the v2 sidecar (__mx_v2_meta__, addr=0, size=0) leak that PR
#295 (commit 53c69ec) fixed; the bad descriptor only trips UCX's
validator on real cross-node rc_mlx5 / cuda_copy paths, which is what
jthomson04 hit on GB300 RoCE during Dynamo bring-up.
FORCE_RDMA=1 sets UCX_TLS=self,sm,rc_mlx5,cuda_copy,tcp (omitting
cuda_ipc), routing intra-node demos through the same strict validator
that cross-node would use. Pre-deploy runs should set this so future
descriptor-list bugs of this shape surface in loopback rather than
waiting for a real cross-host deployment.
Usage:
    FORCE_RDMA=1 WORLD_SIZE=4 python3 v2_moe_e2e_demo.py
Background: pensieve/RL/NemoRL/07_dynamo_handoff_2026_05_18.md §"The
real root cause: v2 sidecar leaking into RDMA descriptor list".
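A sketch of how a demo script might translate the FORCE_RDMA switch into the UCX transport list, assuming the variable is applied before any UCX-backed library initializes. `apply_force_rdma` is a hypothetical helper; the demo scripts' actual code is not shown here.

```python
from typing import Dict, Mapping

# Transport list from the commit message; cuda_ipc is deliberately absent.
STRICT_UCX_TLS = "self,sm,rc_mlx5,cuda_copy,tcp"


def apply_force_rdma(env: Mapping[str, str]) -> Dict[str, str]:
    """Return a copy of env with UCX_TLS pinned when FORCE_RDMA=1 is set."""
    updated = dict(env)
    if updated.get("FORCE_RDMA") == "1":
        updated["UCX_TLS"] = STRICT_UCX_TLS
    return updated


forced = apply_force_rdma({"FORCE_RDMA": "1", "WORLD_SIZE": "4"})
untouched = apply_force_rdma({"WORLD_SIZE": "4"})
```

In a real script the result would be written back with `os.environ.update(...)` at the top of the file, ahead of the imports that bring up the transport.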
…source signature
After rebasing kavink/nemo_rl_moe onto current main, two API skew
issues surfaced that the rebase mechanical-merge couldn't see:
1. `nemo_rl_v2.py` imports `HeartbeatThread` from the flat
`modelexpress.heartbeat` path, but main moved that module into the
`modelexpress.metadata` subpackage. Crashed with
`ModuleNotFoundError: No module named 'modelexpress.heartbeat'` on
the trainer when DTensorPolicyWorker.stream_weights_via_mx first
imports nemo_rl_v2. Fix: import from `.metadata.heartbeat`.
2. `MxRefitReceiver.receive_weights_scratch` passes
`coalesce_transfers=False` to
`NixlTransferManager.receive_from_source`. Main removed that
parameter (coalescing is now gated entirely by `MX_POOL_REG`), so
the call raises `TypeError: receive_from_source() got an unexpected
keyword argument 'coalesce_transfers'`. Worker logs showed the
receiver looping with "[mx-poller] refit failed for version N; will
retry" while the trainer silently proceeded on stale weights (the
dynamo extension acks the refit RPC immediately and runs the actual
receive in a background poller). Fix: drop the obsolete kwarg.
E2E validated on Qwen3-4B-Thinking + Dynamo vLLM v1 GRPO smoke, both
refit cycles complete with clean version transitions:
    step=1: RDMA transfer complete: 8.82 GB, 399 tensors, 0.20s, 357.5 Gbps
    step=2: RDMA transfer complete: 8.82 GB, 399 tensors, 0.18s, 384.4 Gbps
    [mx-poller] refit OK to version 1
    [mx-poller] refit OK to version 2
Signed-off-by: John Thomson <jothomson@nvidia.com>
…ulti-source slice planner
Implements Phases 3a, 3b, and 4 of the post-PR-#2389 RFC
(KavinKrishnan/prime-rl: kavink/post-2389-kernel-compile-plan, RFC §3.3
and §5):
Phase 3a: extend TensorDescriptorV2 with `compile_target` and
`compile_metadata`. Default `compile_target=hf_raw` keeps the wire
byte-compat with pre-Phase-3 candidates: encode_registry omits the
field when it's the default. Adds a `compile_target_matches` helper for
receiver-side filtering (whitelist + required-metadata-subset). New
canonical string constants: hf_raw, vllm_fused, deep_gemm_fp8,
cutlass_fp8, trtllm.
Phase 3b: extend MxV2RefitReceiver.discover_v2_sources with
`compile_target_filter` and `required_compile_metadata`. Candidates
without a v2 registry are rejected when either is set (we can't certify
bytes blindly). Candidates with mixed compile targets are rejected if
any tensor's target is outside the allowed set. V2SourceCandidate now
exposes `compile_targets: frozenset[str]` for caller introspection.
Phase 4: multi-source slice discovery for mixed trainer/inference TP:
- New types: TargetTPLayout, SliceSource, SliceCoveragePlan.
- New method
MxV2RefitReceiver.discover_v2_sources_for_slice(target_layout=…):
walks all v2 candidates per tensor, intersects their
local_shard_range with the receiver's requested slice, emits a
minimal SliceSource list. Detects coverage gaps + shard_axis
mismatches and surfaces them in plan.missing.
- New method MxV2RefitReceiver.receive_via_plan(plan): orchestrates the
scratch RDMA pulls per contributing candidate and stitches results
with torch.cat along the shard axis. v0 issues one full scratch fetch
per candidate; byte-level partial RDMA is a follow-up (RFC §5 Phase
4.5).
Unit tests (33 total, all green):
- Phase 3a: compile_target default, round-trip,
wire-omission-when-default, matches helper with whitelist and
required-metadata-subset (6 tests).
- Phase 3b: filter accepts matching / unset admits all / rejects when
no registry / required-metadata pinning works (4 tests).
- Phase 4: planner covers within-shard slice, planner stitches
cross-shard slice, planner picks freshest for REPLICATE, planner
flags coverage gap, planner flags shard_axis mismatch (5 tests).
receive_via_plan stitches two sources byte-exactly, passes through
single source, refuses uncovered plan (3 tests).
Module __init__.py re-exports: TargetTPLayout, SliceCoveragePlan,
SliceSource, V2SourceCandidate, TensorDescriptorV2,
compile_target_matches, all COMPILE_TARGET_* constants.
Does not modify the existing single-source receive paths or the v1 fat
clients, so existing demos (v2_dtensor_e2e_demo.py,
v2_moe_e2e_demo.py) and downstream consumers (NemoRL
kavink/mx_integration, jthomson04/RL Dynamo path) continue to compile +
run unchanged.
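The Phase 4 interval logic (intersect each source's shard range with the requested slice, then report gaps) can be sketched on a single axis. `plan_slice` and its tuple-based inputs are hypothetical simplifications; the real planner also handles shard_axis checks, freshness, and REPLICATE placements, none of which appear here.

```python
from typing import List, Tuple

Range = Tuple[int, int]  # half-open [start, stop) along the shard axis


def plan_slice(
    requested: Range, shards: List[Tuple[str, Range]]
) -> Tuple[List[Tuple[str, int, int]], List[Range]]:
    """Return (pieces, missing) for a requested slice over published shards."""
    pieces = []
    for source_id, (lo, hi) in sorted(shards, key=lambda s: s[1][0]):
        # Clip this source's shard to the part the receiver asked for.
        start, stop = max(lo, requested[0]), min(hi, requested[1])
        if start < stop:
            pieces.append((source_id, start, stop))

    # Sweep left to right and record any interval no source covers.
    missing = []
    cursor = requested[0]
    for _source_id, start, stop in pieces:
        if start > cursor:
            missing.append((cursor, start))
        cursor = max(cursor, stop)
    if cursor < requested[1]:
        missing.append((cursor, requested[1]))
    return pieces, missing


# A slice that straddles two trainer shards.
covered = plan_slice((128, 384), [("rank0", (0, 256)), ("rank1", (256, 512)), ("rank2", (512, 768))])

# The same idea with rank1 absent: the planner reports the hole.
gapped = plan_slice((128, 640), [("rank0", (0, 256)), ("rank2", (512, 768))])
```

A receiver would refuse to pull when `missing` is non-empty and otherwise concatenate the pieces in order along the shard axis.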
…_tensor
Phase-3 graduation glue. After this commit, callers of MxV2TrainingPublisher
can tag each tensor with its kernel layout at publish time, and the tag
flows end-to-end into the v2 sidecar TensorDescriptor's compile_target +
compile_metadata fields. Receivers can then filter via
discover_v2_sources(compile_target_filter=..., required_compile_metadata=...)
without needing to inspect the raw bytes.
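A sketch of the receiver-side check described above: an optional whitelist of targets plus a required-metadata subset. `target_matches` is a hypothetical stand-in for illustration; it is not claimed to mirror the signature of compile_target_matches, and the `block_size` value is made up.

```python
from typing import Collection, Mapping, Optional


def target_matches(
    compile_target: str,
    compile_metadata: Mapping[str, object],
    allowed_targets: Optional[Collection[str]] = None,
    required_metadata: Optional[Mapping[str, object]] = None,
) -> bool:
    """True when the tensor's tag passes both filters; unset filters admit all."""
    if allowed_targets is not None and compile_target not in allowed_targets:
        return False
    for key, value in (required_metadata or {}).items():
        # Every required key must be present with an equal value.
        if key not in compile_metadata or compile_metadata[key] != value:
            return False
    return True


experts_tag = ("deep_gemm_fp8", {"block_size": 128})   # illustrative metadata value
lm_head_tag = ("hf_raw", {})

accepts_fp8 = target_matches(*experts_tag, allowed_targets={"deep_gemm_fp8"},
                             required_metadata={"block_size": 128})
rejects_raw = target_matches(*lm_head_tag, allowed_targets={"deep_gemm_fp8"})
admits_all = target_matches(*lm_head_tag)
```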
Callers will typically pass these from their trainer-side conversion
registry. For prime-rl, that's
prime_rl.trainer.models.conversions.ConversionEntry.{compile_target,
compile_metadata} (Phase 3 trainer-side PR on
KavinKrishnan/prime-rl#kavink/post-2389-conversion-registry-extensions).
Defaults preserve back-compat: compile_target='hf_raw', compile_metadata={}.
Existing callers that don't pass either kwarg get the unchanged
behaviour and the descriptor's wire form is byte-identical to before
(encode_registry omits these fields when at defaults).
Tests: 35/35 green. Two new tests:
- test_phase3_add_tensor_threads_compile_target: 3-tensor mix
(lm_head=hf_raw default, gate_proj=cutlass_fp8 with metadata,
experts=deep_gemm_fp8 with block_size metadata) all flow correctly
into the publisher's internal registry.
- test_phase3_add_tensor_compile_target_survives_encode_decode:
end-to-end wire round-trip via encode_registry + decode_registry.
Doesn't change any existing call sites; this is the API extension that
makes the Phase 3 PR consumable.
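The back-compat claim (fields omitted from the wire when at their defaults) can be illustrated with a small sketch. JSON is used here purely as an assumed stand-in for the real wire format, and `encode_descriptor_sketch` / `decode_descriptor_sketch` are hypothetical names, not `encode_registry` / `decode_registry` themselves.

```python
import json

DEFAULT_COMPILE_TARGET = "hf_raw"


def encode_descriptor_sketch(name, shape, compile_target=DEFAULT_COMPILE_TARGET,
                             compile_metadata=None):
    # Start from the fields that already existed before Phase 3.
    payload = {"name": name, "shape": list(shape)}
    # Only emit the new fields when they differ from their defaults,
    # so a default descriptor serializes exactly as it did before.
    if compile_target != DEFAULT_COMPILE_TARGET:
        payload["compile_target"] = compile_target
    if compile_metadata:
        payload["compile_metadata"] = dict(compile_metadata)
    return json.dumps(payload, sort_keys=True).encode("utf-8")


def decode_descriptor_sketch(blob):
    payload = json.loads(blob.decode("utf-8"))
    # Missing fields decode to the defaults, which also covers old payloads.
    payload.setdefault("compile_target", DEFAULT_COMPILE_TARGET)
    payload.setdefault("compile_metadata", {})
    return payload


# What a pre-Phase-3 encoder would have produced for the same tensor.
legacy_blob = json.dumps({"name": "lm_head", "shape": [8, 4]}, sort_keys=True).encode("utf-8")
default_blob = encode_descriptor_sketch("lm_head", (8, 4))
tagged_blob = encode_descriptor_sketch("experts", (8, 4), "deep_gemm_fp8", {"block_size": 128})
```

In this sketch `default_blob` equals `legacy_blob` byte for byte, while `tagged_blob` round-trips its target and metadata through decode.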
…dapter for vLLM native RL APIs
Two additive surfaces on top of the existing Phase 2/3/4 work.
----------------------------------------------------------------------
1. Transfer-metrics instrumentation
----------------------------------------------------------------------
New TransferStats dataclass on the receiver side captures per-receive
metrics (bytes, tensors, elapsed, bandwidth_gbps, discovery_seconds,
path, training_step, source_worker_rank) in structured form. All three
MxRefitReceiver receive paths populate it:
- receive_weights → path="pre_registered"
- receive_weights_scratch → path="scratch"
- receive_weights_from_metadata → path="from_metadata"
Exposed on the receiver as `last_stats` (latest call) and `history`
(full per-call list). The v2 layer additionally tracks
`_last_discovery_seconds` (control-plane round-trip time, distinct from
data-plane RDMA time) so benchmarks can compare them.
The numbers were already produced inline as log lines; this commit
captures them as queryable state so benchmark harnesses don't have to
parse logs.
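A minimal sketch of what such a stats record and its `last_stats` / `history` accessors might look like. The field names follow the list above; the types, the choice to derive `bandwidth_gbps` as a property, and the decimal gigabits-per-second unit are assumptions, and both class names are hypothetical.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class TransferStatsSketch:
    bytes: int
    tensors: int
    elapsed: float  # data-plane seconds
    path: str       # "pre_registered", "scratch", or "from_metadata"
    discovery_seconds: Optional[float] = None
    training_step: Optional[int] = None
    source_worker_rank: Optional[int] = None

    @property
    def bandwidth_gbps(self) -> float:
        # Assumed unit: decimal gigabits per second.
        if self.elapsed <= 0:
            return 0.0
        return self.bytes * 8 / self.elapsed / 1e9


class ReceiverStatsSketch:
    """Hypothetical holder for per-call stats on the receiver side."""

    def __init__(self) -> None:
        self.history: List[TransferStatsSketch] = []

    @property
    def last_stats(self) -> Optional[TransferStatsSketch]:
        # None until the first receive has been recorded.
        return self.history[-1] if self.history else None

    def record(self, stats: TransferStatsSketch) -> None:
        self.history.append(stats)


receiver = ReceiverStatsSketch()
receiver.record(TransferStatsSketch(bytes=1_000_000_000, tensors=12,
                                    elapsed=2.0, path="scratch"))
```

A benchmark harness could then read `receiver.last_stats.bandwidth_gbps` directly instead of scraping log lines.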
----------------------------------------------------------------------
2. MxWeightTransferEngine — vLLM native RL APIs adapter (Option A from
pensieve/RL/PrimeRL/10_*.md)
----------------------------------------------------------------------
New module modelexpress/vllm_weight_transfer.py implements the
WeightTransferEngine abstract base introduced in the 2026-05-28 vLLM
Native RL APIs blog. RL frameworks pick it up via:
import modelexpress.vllm_weight_transfer # registers "mx_nixl"
llm = LLM(..., weight_transfer_config=WeightTransferConfig(backend="mx_nixl"))
Three info dataclasses on the wire:
MxInitInfo — mx_server_url, model_name, worker_rank,
agent_name, device_id, publish_self_as_replica
MxUpdateInfo — version, target_tp_layout (Phase 4),
compile_target_filter (Phase 3b),
required_compile_metadata (Phase 3b),
same_rank_only + dedup_freshest_per_rank (Phase 2)
MxTrainerSendArgs — publisher, version, compile_target,
compile_metadata, per-tensor expert metadata
Dispatch logic in receive_weights:
- target_tp_layout=None → matched-TP fast path: discover_v2_sources +
pick_best_source + receive_from (single source, same-rank)
- target_tp_layout=set → Phase-4 path: discover_v2_sources_for_slice
+ receive_via_plan (multi-source stitched)
Both paths apply the Phase 3 compile_target_filter +
required_compile_metadata at discovery time, refusing incompatible
sources BEFORE spending any RDMA cycles.
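The two-way dispatch can be sketched against a stub receiver. The method names are the ones listed above, but everything else is assumed for illustration: `StubReceiver` is hypothetical, `pick_best_source` is modeled as a receiver method, and the plan is a plain dict rather than a `SliceCoveragePlan`.

```python
class StubReceiver:
    """Hypothetical receiver that records which methods were called."""

    def __init__(self, candidates=("candidate-0",), covered=True):
        self.calls = []
        self._candidates = list(candidates)
        self._covered = covered

    def discover_v2_sources(self, **filters):
        self.calls.append("discover_v2_sources")
        return list(self._candidates)

    def pick_best_source(self, candidates):
        self.calls.append("pick_best_source")
        return candidates[0] if candidates else None

    def receive_from(self, candidate):
        self.calls.append("receive_from")
        return [("weight", candidate)]

    def discover_v2_sources_for_slice(self, target_layout, **filters):
        self.calls.append("discover_v2_sources_for_slice")
        return {"fully_covered": self._covered, "layout": target_layout}

    def receive_via_plan(self, plan):
        self.calls.append("receive_via_plan")
        return [("weight", plan["layout"])]


def dispatch_receive(receiver, *, target_tp_layout=None,
                     compile_target_filter=None, required_compile_metadata=None):
    # Filters are applied at discovery time on both paths.
    filters = {
        "compile_target_filter": compile_target_filter,
        "required_compile_metadata": required_compile_metadata,
    }
    if target_tp_layout is None:
        # Matched-TP fast path: a single source.
        best = receiver.pick_best_source(receiver.discover_v2_sources(**filters))
        if best is None:
            raise RuntimeError("no source matches the filter")
        return receiver.receive_from(best)
    # Phase-4 path: multi-source plan, refused before any transfer if incomplete.
    plan = receiver.discover_v2_sources_for_slice(target_layout=target_tp_layout, **filters)
    if not plan["fully_covered"]:
        raise RuntimeError("slice plan is not fully covered")
    return receiver.receive_via_plan(plan)
```

Both failure branches raise before any receive call is made, which mirrors the "refuse before RDMA" behavior described above.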
After a successful receive, optionally calls publish_self_as_source for
tree fan-out / pipeline replication (TensorHub pattern) — controlled by
MxInitInfo.publish_self_as_replica (default True). Failure of this
best-effort optimization does NOT propagate.
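A rough sketch of the best-effort pattern, assuming the receive and publish steps are passed in as plain callables. `receive_then_replicate` and the logger name are hypothetical; only the flag name `publish_self_as_replica` and its default come from the text above.

```python
import logging

logger = logging.getLogger("mx_sketch")


def receive_then_replicate(receive, publish_self_as_source,
                           publish_self_as_replica=True):
    # The receive itself is not guarded: its failures must propagate.
    weights = list(receive())
    if publish_self_as_replica:
        try:
            publish_self_as_source()
        except Exception:
            # Fan-out is an optimization, so log and carry on.
            logger.warning("publish_self_as_source failed; continuing", exc_info=True)
    return weights


def failing_publish():
    raise RuntimeError("mx-server unreachable")


result = receive_then_replicate(lambda: [("w0", 1), ("w1", 2)], failing_publish)
```

Here `result` still holds both tensors even though the publish step raised.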
Trainer-side classmethod trainer_send_weights threads the compile_target
+ compile_metadata + per-tensor expert metadata from MxTrainerSendArgs
into each MxV2TrainingPublisher.add_tensor() call, then publishes once
with the version.
Registration with vLLM is via WeightTransferEngineFactory at import
time, with a try/except so the module is import-safe in environments
without vLLM (tests, publisher-only nodes, benchmark harnesses). The
MX_WEIGHT_TRANSFER_AUTOREGISTER=0 env var disables auto-registration
even when vLLM IS installed.
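The import-safe registration pattern might look roughly like this. The helper `maybe_autoregister` is hypothetical, and the actual vLLM factory call is deliberately left as a caller-supplied `register` callback, since its exact signature is not shown in this commit message; only the `MX_WEIGHT_TRANSFER_AUTOREGISTER` variable comes from the text.

```python
import importlib
import os


def maybe_autoregister(module_name, register, env=None):
    """Try to register a backend; never raise if the host package is absent."""
    env = os.environ if env is None else env
    # Explicit opt-out wins, even when the package is importable.
    if env.get("MX_WEIGHT_TRANSFER_AUTOREGISTER", "1") == "0":
        return "disabled"
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Tests, publisher-only nodes, and benchmark harnesses land here.
        return "unavailable"
    register(module)
    return "registered"


registered = []
status_missing = maybe_autoregister("package_that_is_not_installed_xyz", registered.append, env={})
status_off = maybe_autoregister("json", registered.append,
                                env={"MX_WEIGHT_TRANSFER_AUTOREGISTER": "0"})
status_ok = maybe_autoregister("json", registered.append, env={})
```

The standard-library `json` module stands in for vLLM so the sketch runs anywhere; in the real module this logic would run once at import time.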
Metrics surface exposed for benchmarks:
engine.last_transfer_stats — TransferStats from most recent receive
engine.transfer_history — full per-call history
engine.last_discovery_seconds — most recent control-plane round-trip
----------------------------------------------------------------------
Tests
----------------------------------------------------------------------
14 new unit tests in test_vllm_weight_transfer.py (all green). Categories:
- construction (with + without init_info, error before init)
- matched-TP fast path (yielded tensors reach load_weights)
- mixed-TP Phase 4 path (slice plan built + stitched)
- uncovered-plan rejection (Phase 4 fully_covered=False raises)
- no-source-matches-filter rejection (Phase 3b fast path)
- kwarg passthrough (compile_target_filter, required_compile_metadata,
same_rank_only, dedup_freshest_per_rank all reach the receiver)
- publish_self_as_replica triggered post-receive
- publish failure swallowed (tree fan-out is best-effort)
- trainer_send_weights threads compile_target/compile_metadata/expert
metadata to add_tensor + publishes with version
- metrics surface returns sensible Nones / empties pre-init
- factory contract: init_info_cls + update_info_cls class attrs
All 49 v2 tests still green (35 prior + 14 new). Companion design doc
at pensieve/RL/PrimeRL/10_mx_weight_transfer_engine_design.md.
…rk suite
Adds the three-scenario transport-layer benchmark we need to demonstrate
the MX v2 + MxWeightTransferEngine integration on real hardware:
- elastic_scale — N receivers join staggered; measures cold-start
join latency, per-cycle Gbps, control-plane vs
data-plane latency split
- compile_target — three concurrent receivers (matched filter,
mismatched filter, no filter) prove the Phase 3b
safety net refuses incompatible bytes BEFORE RDMA
and that the no-filter path stays back-compatible
- tree_fanout — same as elastic but receivers also
publish_self_as_source; measures trainer egress
vs total delivered (fanout factor)
The harness drives both ends through the new MxWeightTransferEngine
adapter, so the numbers reflect what RL frameworks will see going
through vLLM's native WeightTransferEngine interface — not a special
private API.
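For the tree_fanout scenario, the egress accounting could be computed along these lines. Defining the fanout factor as total delivered bytes divided by trainer egress bytes is an assumption on my part, as are the function name and the `(source_id, bytes)` record shape.

```python
from typing import Dict, Iterable, Optional, Tuple

TRAINER = "trainer"


def fanout_summary(pulls: Iterable[Tuple[str, int]]) -> Dict[str, Optional[float]]:
    """pulls: one (source_id, bytes_received) record per receiver pull."""
    pulls = list(pulls)
    total = sum(nbytes for _, nbytes in pulls)
    # Only bytes served directly by the trainer count as trainer egress.
    egress = sum(nbytes for source, nbytes in pulls if source == TRAINER)
    factor = total / egress if egress else None
    return {
        "trainer_egress_bytes": egress,
        "total_delivered_bytes": total,
        "fanout_factor": factor,
    }


# Four receivers: one pulls from the trainer, three pull from peer replicas.
with_fanout = fanout_summary([
    (TRAINER, 100), ("receiver-0", 100), ("receiver-0", 100), ("receiver-1", 100),
])
# Same four receivers, all pulling from the trainer.
without_fanout = fanout_summary([(TRAINER, 100)] * 4)
```

Under this definition a factor of 1.0 means the trainer served every byte itself, and larger values mean replicas absorbed more of the load.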
Two run modes:
- --mode=live (default) launches the trainer + receivers as
subprocesses against a real MX server + NIXL data plane. Needs
CUDA + NIXL + a reachable mx-server URL.
- --mode=cpu is a stubbed orchestrator-only smoke. Exercises
result aggregation + summary table generation without touching
MX or torch.cuda. Used in CI and for local development.
Output: human-readable summary table on stdout + machine-readable
JSON via --output (schema documented in benchmarks/README.md).
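A sketch of how the CLI surface described here could be declared with `argparse`. Only `--mode` (with `live` as default) and `--output` are taken from the text; the `--scenario` flag and its default are hypothetical, added so the three scenario names have somewhere to go.

```python
import argparse

SCENARIOS = ("elastic_scale", "compile_target", "tree_fanout")


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(description="MX v2 transport benchmark (sketch)")
    # live: real MX server + NIXL; cpu: stubbed orchestrator-only smoke run.
    parser.add_argument("--mode", choices=("live", "cpu"), default="live")
    # Optional path for the machine-readable JSON results.
    parser.add_argument("--output", default=None)
    # Hypothetical flag: the real harness may select scenarios differently.
    parser.add_argument("--scenario", choices=SCENARIOS, default="elastic_scale")
    return parser


defaults = build_parser().parse_args([])
smoke = build_parser().parse_args(["--mode=cpu", "--output", "out.json"])
```

With these declarations, an unrecognized mode is rejected by `argparse` itself before any benchmark code runs.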
Companion changes:
- MxV2RefitReceiver gains receive_from_scratch(candidate): wraps
receive_weights_scratch so callers without pre-registered model
buffers (the benchmark, cold-start vLLM workers) can still
drive the matched-TP path
- MxWeightTransferEngine switches the matched-TP path to use
receive_from_scratch by default, matching the comment that was
already in the source: scratch mode is the right cold-start
behavior and matches Anyscale's RDT plugin pattern. When vLLM
exposes register_destinations (proposed extension §5.1 in the
design doc), this can switch to the zero-copy path.
- The two engine tests that monkeypatched receive_from are updated
to monkeypatch receive_from_scratch instead
Tests:
- 9 new orchestrator unit tests cover bandwidth math, trainer-egress
accounting (both with and without fan-out), compile_target verdict
derivation, percentile helpers, summary-table conditional rendering,
CLI arg parsing, and an end-to-end --mode=cpu CLI run that writes
JSON to a tmp path
- All 58 v2 unit tests still green (35 shape/picker + 14 engine + 9
bench)
Companion doc at pensieve/RL/PrimeRL/11_benchmark_results.md with the
methodology + acceptance criteria for each scenario; numbers tables
left as PENDING until the cluster Teleport access is refreshed.
Lets us go from "Teleport auth refreshed" → "three JSON result files
in hand" with one command:
./run_cluster_bench.sh
The Job manifest (k8s/bench-elastic.yaml) runs all three scenarios in
sequence inside one pod in the kavin namespace. The pod uses the
existing prime-rl image (any image with modelexpress installed); the
harness lives at modelexpress/benchmarks/bench_elastic_scaling.py
inside the image, so no separate image build is needed.
The driver script (run_cluster_bench.sh):
- cleans up prior Job runs
- applies the manifest
- waits for the pod to start (up to 2 min)
- optionally tails logs (--watch)
- waits for Job completion (30 min cap)
- kubectl cp's /results/ out to ./results-<timestamp>/
- prints a per-scenario summary
Resource request is 5 GPUs (1 trainer + 4 receivers for scenarios 1
and 3); change nvidia.com/gpu in the manifest if your namespace has
different quota. Tolerations + nodeSelector target the GB200 pool by
default; comment them out or replace for other environments.
When run, results files map directly to the tables in
pensieve/RL/PrimeRL/11_benchmark_results.md — paste the per-receiver
rows from the JSON files into the matching sections.
Summary
Implements Phases 3a, 3b, and 4 of the post-PR-#2389 RFC. Draft because it depends on the upstream RFC being agreed before we ask the MX team to review; opening early so the v2 client surface is visible alongside upstream PrimeIntellect-ai/prime-rl#2389.
Phase 3a: `compile_target` + `compile_metadata` on `TensorDescriptorV2`
- Defaults: `hf_raw` / `{}`; encoder omits them from the wire when default so existing payloads are byte-identical.
- `COMPILE_TARGET_HF_RAW`, `_VLLM_FUSED`, `_DEEPGEMM_FP8`, `_CUTLASS_FP8`, `_TRTLLM`.
- `compile_target_matches(descriptor, *, allowed_targets, required_metadata=None)` for receiver-side filtering with whitelist + required-metadata-subset semantics.
Phase 3b: `compile_target_filter` on `MxV2RefitReceiver.discover_v2_sources`
- `compile_target_filter` (whitelist set) and `required_compile_metadata` (dict that must be a subset of every tensor's `compile_metadata`).
- `V2SourceCandidate` exposes `compile_targets: frozenset[str]` for caller introspection.
Phase 4: Multi-source slice discovery for mixed trainer/inference TP
- `TargetTPLayout`, `SliceSource`, `SliceCoveragePlan`.
- `MxV2RefitReceiver.discover_v2_sources_for_slice(target_layout=…)`: walks v2 candidates per tensor, intersects each publisher's `local_shard_range` against the receiver's requested slice, emits the minimal candidate set covering it. Surfaces coverage gaps + `shard_axis` mismatches in `plan.missing`.
- `MxV2RefitReceiver.receive_via_plan(plan)`: orchestrates the per-candidate scratch RDMA pulls and stitches results via `torch.cat` along the shard axis. v0 issues one full scratch fetch per contributing candidate; byte-level partial RDMA is a Phase-4.5 follow-up (RFC §5).
Tests
33/33 green (29 new):
- `test_v2_shape_registry.py`: default value, round-trip, wire-omission, matches helper whitelist + required-metadata behaviour (6 new tests).
- `test_v2_source_picker.py`: filter accepts matching, unset admits all, rejects when no registry, required-metadata pinning (4 new tests).
- `test_v2_source_picker.py`: planner covers within-shard slice, planner stitches cross-shard slice, planner picks freshest REPLICATE, planner flags coverage gap, planner flags shard_axis mismatch (5 new).
- `receive_via_plan` stitches two sources byte-exactly, passes through single source, refuses uncovered plan (3 new).
- Plus matching exports in `__init__.py`.
Compat
Existing demos (`v2_dtensor_e2e_demo.py`, `v2_moe_e2e_demo.py`) and downstream consumers (NemoRL `kavink/mx_integration`, jthomson04/RL Dynamo path) compile + run unchanged: every new arg is optional and every new field has a backwards-compatible default.
Test plan
- `python -m pytest modelexpress_client/python/tests/`
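The Phase 4 planning step (intersect each publisher's shard range with the requested slice, then stitch in order) can be sketched in plain Python. This is a simplified greedy cover along one axis and is not guaranteed minimal when shards overlap; `ShardSketch`, `plan_slice`, and `stitch` are hypothetical names, and list concatenation stands in for `torch.cat` so the example runs without torch.

```python
from dataclasses import dataclass
from typing import Dict, List, Sequence, Tuple


@dataclass(frozen=True)
class ShardSketch:
    worker_rank: int
    start: int  # inclusive row index on the shard axis
    stop: int   # exclusive row index on the shard axis


def plan_slice(shards: Sequence[ShardSketch], want_start: int, want_stop: int):
    """Return (pieces, missing) for the requested [want_start, want_stop) range."""
    pieces: List[Tuple[int, int, int]] = []   # (worker_rank, start, stop)
    missing: List[Tuple[int, int]] = []       # uncovered (start, stop) gaps
    cursor = want_start
    for shard in sorted(shards, key=lambda s: (s.start, -s.stop)):
        lo = max(shard.start, cursor)
        hi = min(shard.stop, want_stop)
        if lo >= hi:
            continue  # this shard adds no new coverage
        if lo > cursor:
            missing.append((cursor, lo))  # hole before this shard
        pieces.append((shard.worker_rank, lo, hi))
        cursor = hi
    if cursor < want_stop:
        missing.append((cursor, want_stop))
    return pieces, missing


def stitch(pieces, shard_rows: Dict[int, Tuple[int, list]]) -> list:
    # shard_rows maps worker_rank -> (first global row index, rows held).
    out: list = []
    for rank, lo, hi in pieces:
        base, rows = shard_rows[rank]
        out.extend(rows[lo - base:hi - base])
    return out


# Trainer TP=2 holds rows [0, 4) and [4, 8); the receiver wants rows [2, 6).
shards = [ShardSketch(0, 0, 4), ShardSketch(1, 4, 8)]
pieces, missing = plan_slice(shards, 2, 6)
rows = stitch(pieces, {0: (0, [0, 1, 2, 3]), 1: (4, [4, 5, 6, 7])})
gap_pieces, gap_missing = plan_slice(shards[:1], 2, 6)
```

In the two-shard case the plan draws on both trainer ranks and leaves nothing missing; with only rank 0 available, the tail of the requested range shows up in `gap_missing`, which is the condition under which `receive_via_plan` is described as refusing to proceed.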