Skip to content

feat(RL/post-2389): Phase 3+4 v2 client — compile-target registry + multi-source slice planner#349

Draft
KavinKrishnan wants to merge 40 commits into
mainfrom
kavink/post-2389-phase3-4
Draft

feat(RL/post-2389): Phase 3+4 v2 client — compile-target registry + multi-source slice planner#349
KavinKrishnan wants to merge 40 commits into
mainfrom
kavink/post-2389-phase3-4

Conversation

@KavinKrishnan
Copy link
Copy Markdown
Contributor

Summary

Implements Phases 3a, 3b, and 4 of the post-PR-#2389 RFC. Draft because it depends on the upstream RFC being agreed before we ask MX team to review; opening early so the v2 client surface is visible alongside upstream PrimeIntellect-ai/prime-rl#2389.

Phase 3a — compile_target + compile_metadata on TensorDescriptorV2

  • New fields default to hf_raw / {}; encoder omits them from the wire when default so existing payloads are byte-identical.
  • New constants: COMPILE_TARGET_HF_RAW, _VLLM_FUSED, _DEEPGEMM_FP8, _CUTLASS_FP8, _TRTLLM.
  • New helper compile_target_matches(descriptor, *, allowed_targets, required_metadata=None) for receiver-side filtering with whitelist + required-metadata-subset semantics.

Phase 3b — compile_target_filter on MxV2RefitReceiver.discover_v2_sources

  • New kwargs compile_target_filter (whitelist set) and required_compile_metadata (dict that must be a subset of every tensor's compile_metadata).
  • Candidates with no v2 registry are rejected when either filter is set — bytes are unknowable.
  • Candidates with mixed compile targets are rejected if any tensor's target falls outside the allowed set (no partial-trust).
  • V2SourceCandidate exposes compile_targets: frozenset[str] for caller introspection.

Phase 4 — Multi-source slice discovery for mixed trainer/inference TP

  • New types: TargetTPLayout, SliceSource, SliceCoveragePlan.
  • New method MxV2RefitReceiver.discover_v2_sources_for_slice(target_layout=…): walks v2 candidates per tensor, intersects each publisher's local_shard_range against the receiver's requested slice, emits the minimal candidate set covering it. Surfaces coverage gaps + shard_axis mismatches in plan.missing.
  • New method MxV2RefitReceiver.receive_via_plan(plan): orchestrates the per-candidate scratch RDMA pulls and stitches results via torch.cat along the shard axis. v0 issues one full scratch fetch per contributing candidate; byte-level partial RDMA is a Phase-4.5 follow-up (RFC §5).

Tests

33/33 green (29 new):

  • Phase 3a — test_v2_shape_registry.py: default value, round-trip, wire-omission, matches helper whitelist + required-metadata behaviour (6 new tests).
  • Phase 3b — test_v2_source_picker.py: filter accepts matching, unset admits all, rejects when no registry, required-metadata pinning (4 new tests).
  • Phase 4 — test_v2_source_picker.py: planner covers within-shard slice, planner stitches cross-shard slice, planner picks freshest REPLICATE, planner flags coverage gap, planner flags shard_axis mismatch (5 new). receive_via_plan stitches two sources byte-exactly, passes through single source, refuses uncovered plan (3 new). Plus matching exports in __init__.py.

Compat

Existing demos (v2_dtensor_e2e_demo.py, v2_moe_e2e_demo.py) and downstream consumers (NemoRL kavink/mx_integration, jthomson04/RL Dynamo path) compile + run unchanged — every new arg is optional and every new field has a backwards-compatible default.

Test plan

  • 33/33 unit tests green via python -m pytest modelexpress_client/python/tests/
  • Cluster validation (NeMo-RL refit on GB200 — separate run; image rebuild required)
  • Cluster validation (PrimeRL post-#2389 refit on GB200 — sequencing PR after upstream merge)

Presentation covering ModelExpress integration into RL post-training
weight sync (refit) for NeMo RL, verl, and PRIME-RL. Includes HTML
slideshow and standalone SVG diagrams for architecture, transfer flow,
component stack, and framework comparison.

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Mirrors all 6 slides from the HTML presentation in plain markdown
for easier viewing on GitHub and compatibility with Marp/Slidev.

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Training-side publisher registers updated model weights with NIXL and
publishes metadata to the MX Server. Inference-side receiver discovers
sources via ListSources, pulls weights via RDMA, and yields (name, tensor)
pairs compatible with vLLM's load_weights(). Supports both all-at-once
and layer-by-layer streaming patterns for PRIME-RL integration.

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
…hash mismatch)

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Tensor memory addresses don't change between optimizer steps, only
values do.  Calling register_memory every step accumulated descriptors,
inflating the metadata blob from ~27 KB to 800+ KB and causing
NIXL_ERR_NOT_ALLOWED on add_remote_agent.

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Allocates temporary GPU buffers matching the source's tensor layout,
receives via NIXL RDMA, and yields (name, tensor) pairs in HF format.
The caller's model.load_weights() handles name mapping and tensor
fusion (e.g. HF q/k/v -> vLLM qkv_proj).

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
…ible with scratch buffers)

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
…ht reshaping

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
…rams

Covers the MxCheckpointEngine design, Ray actor topology, and GB200
prototype results (10 steps, avg ~1.25s cross-node RDMA weight sync).

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
The section referenced recovery/ paths outside the ModelExpress repo that
aren't accessible to external readers.

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
…og value

- Rename §2 to frame MX as additive on verl's NIXL checkpoint engine
- Document native nixl ring path positively; optional MX catalog + star
- Add catalog benefits (balancing, multi-source, publish/retire, retention)
- Fix RDMA READ source/destination wording; remove PRIME-RL references
- Tone: prefer native nixl vs consider mx; align metrics cross-refs

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Companion to PRIMERL_MX_OVERVIEW.md (Path A). Documents an MX-shaped
weight broadcast design for prime-rl that uses PI's NIXL transport as
the data plane but exposes ModelExpress's traditional API surface
(model-agnostic, server-mediated, scratch-buffer-default,
cross-framework-portable) instead of PI's per-model conversion_specs +
slot system.

Key positioning:
- Path A = strict overlay on PI's API. Smallest diff. In flight.
- Path B = native MX shape. Larger diff. Staged design only.
- Both ship as discriminator options on weight_broadcast.type
  (existing nixl + new mx coexist).

Documents:
- Why we'd consider B (per-model spec wall, KL drift inheritance,
  elastic shapes, cross-framework alignment).
- File footprint (~600 LOC new, ~70 LOC modified).
- Migration path (toml-only for users).
- Pitch sequence for PI conversation.
- Inflection points to pivot from A to B.

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Fix two diagrams in PRIMERL_MX_OVERVIEW.md that implied MX replaces
PI #2326's entire control plane, when v0.1 is env-var-gated and only
swaps SPG coordinator discovery.

Component diagram:
- Label every imported PI element "(PI, unchanged)" and every MX
  element "(MX overlay, env-var gated)" so reader can see the split.
- Re-draw control-plane edges as dotted green (MX additions):
  publish_spg_coordinator (boot), mark_version_ready (per step),
  discover_spg_coordinator (boot), publish_rollout_source (optional).
- Add the SPG 2-round all_gather_obj edge between trainer and rollout
  labeled as PI code — previously missing entirely, so readers could
  think MX alone wired up NIXL agents.
- Data-plane edge labeled "trainer -> rollout recv buffer" to match
  PI's actual WRITE semantics instead of a generic bidirectional arrow.

Timing diagram:
- Wrap flow in three tinted bands: green boot-time MX discovery
  (the only v0.1 change) vs purple per-step SPG metadata rounds +
  RDMA WRITE (PI, unchanged).
- Move discovery out of the per-step par block into a "once per run"
  boot-time block — register_coordinator / discover_spg_coordinator
  are init-time, not update-time ops.
- Add the SPG 2-round all_gather_obj step that was missing: round 1
  exchanges agent_meta + slot_layout + recv-buffer descriptors,
  round 2 exchanges per-slot xfer descriptors.
- Relabel the RDMA step as "NIXL RDMA WRITE -> rollout's recv buffer"
  so the write direction is explicit.
- Trigger for unpublish changed from vague "next iteration" to
  "async_level >= 1" — ties the mutability contract to the actual
  concurrency regime that needs it.
- Add explicit legend calling out MX-added vs PI-unchanged.

No design changes; docs-only clarification so the diagrams match
what the overlay code actually does.

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Scenario A (PI NIXL direct refit on Qwen3-0.6B, 2 trainer ranks ×
1 inference rank) completed all 20 RL training steps end-to-end
on GB200 on April 24. Update the doc to reflect reality.

Changes:

- Status line: "Design complete ... Metrics below are TBD" →
  "Scenario A green end-to-end" with concrete numbers in the lead
  paragraph (596 MB/push, 310 slots, 100% success, draft PR link).

- §2 observed per-step timing table: replace PI-reported 12-node
  projections with measured scenario A numbers (20/20 steps, 5.1 s
  avg, 596 MB bucket, 310 slots published, rank-0 writes all 310 /
  rank-1 writes 197). B and C cells remain "pending" until the next
  session flips the env vars.

- §2 caveat added explaining that throughput comparison vs PI's
  prod numbers is distorted by our SDPA fallback (ARM64 image
  ships a flash_attn import stub; real kernels are a P1 follow-up).
  NIXL transfer itself is unaffected; step-time inflation is on
  the training-compute side.

- §4.5 metrics matrix: populate scenario A column with measured
  values; mark B/C/D/E as "pending" rather than "TBD" to signal
  the scenarios are scoped + instrumented, just not yet run.

- §4.6 results summary: rewrite from "to be written after the
  benchmark run" to a concrete list of what Scenario A proved
  (foundation validated, per-rank sharding-aware works as PI
  designed, overlay is structurally correct). Add pointers to
  the nine blocker fixes documented in OVERLAY_PR_EXECUTION_STATE.md.
  Keep the B/C/D/E expectations section framed as "next-session
  targets" so readers know what to expect the doc to grow into.

No diagram changes; April 24 timing + component diagram fixes
from 461be85 remain the current shape.

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Path A overlay scenarios A, B, and C all completed 20/20 training
steps on GB200. Update §2, §4.5, §4.6 with the measured numbers and
update §1's status line accordingly. Diagrams unchanged.

§2 (timing diagram + observed timing table):
  - Replace pending B/C cells with measured values.
  - Wire BW per trainer NIC: 7.82-8.84 GB/s (avg ~8.1) — exceeds
    PI's reported ~7.5 GB/s prod target.
  - Aggregate net BW: 35-39 GB/s rank 0 + rank 1 combined.
  - Per-push breakdown (scenario C): convert 60-67 ms + post+wait
    15-16 ms + barrier 1.2-1.6 ms ≈ 80 ms total for 596 MB.
  - Add the pipeline-replication catalog state output (4 sources
    incl. rollout-source-0-*) as the empirical proof the MX-side
    protocol works.

§4.5 metrics matrix:
  - All A/B/C cells now have measured values; D and E flagged as
    deferred (with reasons).
  - 4.5.2 throughput row gains the wire/net BW measurements.

§4.6 results summary rewritten:
  - A: foundation validated, nine blockers resolved.
  - B: MX-mediated discovery validated (single source_id shared
    across all participants), data-path parity with A. Documents
    the metadata_endpoint→nixl_metadata workaround we landed
    during this session.
  - C: pipeline-replication catalog entry confirmed; per-push wire/
    net BW measured. Honestly notes the bandwidth-amplification
    benefit isn't shown end-to-end yet (gated on PI-side dynamic
    SPG world_size for elastic mid-run join).
  - D and E: explicitly deferred with rationale (D needs a drift
    reproducer to be valuable; E gates on dynamic SPG).

§1 status line: "scenarios A, B, and C all green on GB200" replaces
the "scenario A green" wording.

Net for the PR-on-PR: A and B are the strongest evidence (overlay
is additive without regressing data path). C's catalog entry plus
the measured per-push BW round out the picture. D and E are honest
follow-up axes.

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Rebased onto current main (was 3 weeks stale; resolved one trivial
__all__ merge conflict in modelexpress/__init__.py).

Python — correctness fixes:

1. refit_receiver.poll_for_source: was hardcoding training_step=0
   on the returned SourceRef and never filtering on min_step despite
   advertising both in the docstring. ListSourcesResponse instances
   carry only SourceInstanceRef (no extra_parameters), so the actual
   training_step lives on SourceIdentity in the publisher's metadata.
   Now do a per-candidate get_metadata() lookup, parse training_step
   from SourceIdentity.extra_parameters, and skip candidates whose
   step is below the threshold or unparseable. Cost: extra gRPC
   round-trip per candidate; can be removed once training_step is
   surfaced on SourceInstanceRef directly.

2. training_publisher.initialize(): training_framework was
   hardcoded to "prime_rl" in _build_identity, which mislabeled
   verl-published sources. Now a parameter on initialize() (default
   "unknown" so callers know to set it explicitly).

3. training_publisher publish_weights / publish_layer mutual
   exclusivity: publish_layer registers fresh tensors every call but
   publish_weights caches via self._registered, so interleaving the
   two paths could leave NIXL holding only the most-recently-
   registered tensor set. New self._publish_mode tracks which path
   is in use; either method raises if the other was already used on
   this publisher.

4. refit_receiver._DTYPE_MAP: lifted to module scope (was rebuilt
   per call inside receive_weights_scratch).

Docs — content fixes:

5. VERL_MX_OVERVIEW.md deployment-mode table: replaced ❌ / ✅
   emoji markers with plain text per repo "no emojis in markdown"
   guideline.

6. PRIMERL_MX_OVERVIEW.md §3.9: fixed duplicate "byte-exact
   byte-exact" → "byte-exact".

7. MD040: annotated 15 bare ``` fences across MX_RL_OVERVIEW.md,
   PRIMERL_MX_OVERVIEW.md, VERL_MX_OVERVIEW.md,
   PRIMERL_MX_NATIVE_DESIGN.md, mx-rl-integration-slides.md as
   ```text where they were carrying plain prose / ASCII layout.

8. ASCII → mermaid:
   - MX_RL_OVERVIEW.md §Architecture: ASCII trainer/server/inference
     swimlane → sequenceDiagram.
   - PRIMERL_MX_OVERVIEW.md §3.2 DAG buildup: 5-phase ASCII timeline
     → flowchart with one subgraph per phase, plus a per-phase
     bandwidth table.
   - PRIMERL_MX_OVERVIEW.md §3.9 before/after: naive-allgather vs
     overlay per-rank flow → side-by-side flowchart.
   - Slide-deck ASCII bottleneck-bar / 3-column architecture
     intentionally retained: those are CSS-styled visual fallbacks
     for the SVGs, not Markdown rendering targets. The misleading
     "[ INSERT DIAGRAM: diagram-architecture.svg ]" placeholder text
     above the architecture fallback was removed.

No proto / server-side changes. The poll_for_source fix is the
proto-level workaround documented in the CodeRabbit review; the
forward-looking fix (adding training_step directly to
SourceInstanceRef so we don't need the per-candidate get_metadata)
is a follow-up.

Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Adds NIXL_COMPRESSION_STUDY.md to help the NIXL nvCOMP compression
team reproduce our RL weight-transfer payloads using our validated
PRIME-RL and verl workflows.

Three paths documented:
1. Pre-captured data (fastest) — pointer to our existing Qwen2.5-1.5B
   data package (model.safetensors + pre/post RL weights + deltas +
   KV cache, captured from live GB200 deployment).
2. End-to-end reproduction on GB200 via the PRIME-RL overlay (PR
   PrimeIntellect-ai/prime-rl#2343) — deploy scenario A, exec into
   trainer pod, capture state_dict + simulate one RL step + dump KV
   cache. Step-by-step with kubectl commands.
3. Reproduction via verl MxCheckpointEngine (PR ai-dynamo/modelexpress
   #252) — same tensor content, different transport path.

Also covers: compression-relevant properties table, per-tensor layout
for Qwen3-0.6B and Qwen2.5-1.5B, delta analysis notes (BF16 deltas
mostly zero at RL learning rates; FP32 diffs are the meaningful
analysis target), NIXL integration point for nvCOMP (transparent —
compress/decompress at the NIXL layer, no MX or framework changes),
and a model-size scaling table for larger captures.

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
The pre-captured Qwen2.5-1.5B data package referenced in Option 1 of
NIXL_COMPRESSION_STUDY.md isn't in this repo (binary tensors at GB
scale aren't appropriate to commit) and the path I had previously
shown was an internal local checkout. Replace with explicit "request
from kavink@nvidia.com" framing and call out the appropriate channels
(NV S3, internal share, or direct upload to eschmidt@nvidia.com per
the original ask). Add the total package size (~14 GB) so the NIXL
team knows what to expect bandwidth-wise.

Update the "larger models" cross-reference accordingly.

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
Adds the two scripts that produced the Qwen2.5-1.5B data package we
referenced in NIXL_COMPRESSION_STUDY.md so the NIXL team can reproduce
captures themselves on different models / sequence lengths / clusters
without going through us as a manual relay.

New files:
  docs/RL/scripts/capture_weights_and_kv.py
    Standalone — any HF model, any host (CPU or single GPU), no
    cluster / RL framework needed. CLI flags for model, dtype,
    device, output dir, weights/KV-only modes, KV seq_len.

  docs/RL/scripts/capture_on_pod.py
    Inside-a-running-RL-pod variant. Generalized vs the original
    Qwen2.5-1.5B-only capture: --model, --out, --kv-seq-len, --lr
    flags. Captures pre/post RL weights + simulated AdamW step
    delta + KV cache in one pass. Produces the four-directory
    layout (weights_pre_rl/, weights_post_rl/, weight_deltas/,
    kv_cache/) we shipped to the compression team.

  docs/RL/scripts/README.md
    Quick reference for both scripts: when to use each, complete
    CLI examples, output layout, the BF16-deltas-are-mostly-zero
    note + FP32 analysis snippet, pointer back to the main study
    doc.

Updated:
  docs/RL/NIXL_COMPRESSION_STUDY.md
    Option 2 now points at scripts/capture_on_pod.py with kubectl
    cp + exec invocation instead of an inlined heredoc Python
    block. Added Option-2-Step-4 ("standalone capture without a
    running RL deployment") pointing at capture_weights_and_kv.py
    for users who don't want to deploy the full overlay.

The original on-disk capture scripts in our internal recovery
directory are unchanged; this just publishes a generalized,
flag-driven version of each into the public docs tree.

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
Two changes to NIXL_COMPRESSION_STUDY.md:

1. Add a component-view mermaid diagram at the top showing where the
   compression-target tensors actually live (RL refit edge between
   trainer and inference NIXL agents; KV cache edge between prefill
   and decode), with green nodes / edges marking the compression
   surface and purple marking RL-stack infrastructure that wouldn't
   change if nvCOMP slots into the NIXL layer transparently.

2. Drop GKE/cluster-specific assumptions. Previously Option 2 named a
   specific GKE node pool, namespace, registry, and tsh auth flow as
   prerequisites; now it just says "a GB200 cluster (ARM64) with at
   least 2 nodes, container runtime, RDMA-capable interconnect". The
   K8s manifests are flagged as examples that need light edits (ns,
   node selectors, registry, RDMA network annotations) per cluster.
   Hardcoded "kavin" namespace replaced with $NS=<your-namespace>
   throughout the kubectl commands so a copy-paste of the recipe
   works on any cluster.

The capture flow itself was already cluster-agnostic — these edits
just stop the doc reading like it's only reproducible on our exact
GKE shape.

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
Removes both references to eschmidt@nvidia.com from
NIXL_COMPRESSION_STUDY.md so the guide reads as a general team-facing
doc rather than addressed at one inbox. Audience line now just says
"NIXL compression team"; Option 1 channel list trims "direct upload
to your eschmidt@nvidia.com inbox per the original request" down to
"direct upload" — same channel options, no person-specific routing.

Single contact for the data package remains kavink@nvidia.com.

Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
…filter, tree fan-out

Adds the MX-side support for the NemoRL integration design (see
pensieve/RL/NemoRL/04_design_v2_moe_rank_to_rank.md). Built on top of the
existing MxTrainingPublisher / MxRefitReceiver as a Python-only shim, so
this lands without Rust server changes.

What's new:
* shape_descriptors: TensorDescriptorV2 + DTensor placement -> wire format
  helpers. Handles Replicate / Shard / Partial; computes per-rank local
  shard ranges; supports MoE expert axis with owned_expert_ids.
* nemo_rl_v2: MxV2TrainingPublisher / MxV2RefitReceiver / TrainerWorldLayout
  / V2SourceCandidate. Implements:
  - rank-to-rank publish (each rank publishes its own local shard, no
    allgather)
  - same-rank-only routing (PrimeRL GB200 lesson: avoids cross-NIC subnet
    writes that were causing NIXL_ERR_REMOTE_DISCONNECT)
  - freshest-per-rank dedup by updated_at (avoids stale READY peers from
    orchestrator restart cascades, which were causing NIXL_ERR_NOT_ALLOWED)
  - tree fan-out via publish_self_as_source (TensorHub paper 2604.09107v1
    pipeline replication)
  - MoE expert coverage filter for receivers in EP layouts
  - auto heartbeat via existing HeartbeatThread

v2 metadata transport (3-tier fallback):
  1. SourceIdentity.extra_parameters via meta.identity (cleanest; requires
     Rust server to populate the new identity field on
     GetMetadataResponse)
  2. Synthetic TensorDescriptor sidecar __mx_v2_meta__ with JSON in dtype
     field (the path that works against the current server, which drops
     SourceIdentity and most string fields when echoing WorkerMetadata)
  3. WorkerMetadata.agent_name string-encoded marker (legacy fallback)

Proto change: added SourceIdentity identity = 5 to GetMetadataResponse.
Backward-compatible (older clients ignore). Python stubs regenerated.

Tests:
* 15 unit tests covering shape descriptor round-trip (Replicate, Shard,
  MoE expert), expert codec, world-layout codec, picker filtering
  (same-rank, min-version, mx_v2 marker, expert coverage, trainer
  fallback), DAG fan-out, and agent_name fallback for legacy servers.
* scripts/v2_moe_e2e_demo.py: standalone GB200 cluster demo. Validated
  on dynamo-gcp-dev-02 in kavin namespace: 4 ranks x 2 cycles, real
  NIXL RDMA, same-rank routing, freshness dedup, MoE expert sharding,
  sidecar transport. All 8 transfers correct.

Companion changes in NVIDIA/RL on branch kavink/mx_integration adopt
this client.
…lients

Adds the SourceIdentity round-trip that v2 RL clients (NemoRL
update_weights_via_mx, prime-rl PR #2389 follow-ups) need to read
framework-level state from extra_parameters.

Changes:
* metadata_backend.rs: ModelMetadataRecord gains an Option<SourceIdentity>.
  Old records (pre-v2 storage) leave it None.
* p2p_service.rs::get_metadata: populates GetMetadataResponse.identity
  from the record's identity field (the new proto field added in the
  preceding commit on the client branch).
* metadata_backend/redis.rs: SourceAttributesJson gains an
  extra_parameters HashMap (with #[serde(default)] for back-compat) and
  a to_source_identity() method that reconstructs the full
  SourceIdentity from the stored attributes hash. Old records without
  extra_parameters deserialize cleanly.
* metadata_backend/kubernetes.rs: identity stays None for now; CRD
  schema bump is a separate change. v2 clients fall back to the
  sidecar transport (synthetic TensorDescriptor) until then.
* state.rs + p2p_service.rs test fixtures: identity: None added to
  match the new struct shape.

This change is forward-compatible: pre-existing records read into
ModelMetadataRecord with identity=None and old clients ignore the
new GetMetadataResponse.identity field.

Pairs with the proto change in commit 97c0e78 that added
SourceIdentity identity = 5 to GetMetadataResponse, and with the
v2 NemoRL client which already prefers identity.extra_parameters
when available and falls back to the synthetic-tensor-descriptor
sidecar otherwise.

Build/deploy: requires rebuilding modelexpress-server image and
redeploying. The current image at
nvcr.io/nvidian/dynamo-dev/modelexpress-server:latest predates these
fields; the v2 RL prototype works against it via the sidecar transport.
Comprehensive companion doc for the NemoRL upstream PR. Covers:
* motivation (PrimeRL GB200 lessons, TensorHub paper, Composer 2)
* the four design pillars (rank-to-rank publish, tree scale-out, MoE
  expert filtering, explicit shape registry)
* full Python API surface with worked code snippets
* file inventory across kavink/nemo_rl_moe (MX) and
  kavink/mx_integration (NemoRL)
* the three-tier metadata transport workaround for the running server
  (SourceIdentity → __mx_v2_meta__ sidecar → agent_name fallback)
* what was tested:
  - 15/15 unit tests passing (no GPU/NIXL)
  - live cluster gRPC smoke
  - live E2E on GB200, toy scale (4×2 cycles, byte-correct)
  - live E2E on GB200, production scale (4×1.6 GB Qwen3-30B-A3B-shaped
    in 11–16 ms per transfer)
  - arm64 Docker overlay built + smoke-tested
  - explicit list of what was NOT exercised
* server-side patch path (rebuild image to land identity round-trip)
* deployment recipe (config.yaml + expected log lines)
* roadmap (async refit, Megatron, SGLang, dirty-experts, cross-DC,
  drain semantics)

Self-contained for upstream review; replaces no existing doc.
The doc is intended to be self-contained for upstream review; the
internal-pensieve cross-references would dangle for outside readers.
Replaced with public references (PR links + sister docs already in
docs/RL/).
…ensors

Two bugs caught by a real-DTensor (not faked) e2e test on GB200, where
the publisher wraps a torch.distributed.tensor.DTensor instead of a
stand-in object:

1. shape_descriptors.describe_tensor: the SHARD branch assumed
   tensor.shape was the LOCAL view, which is true for plain tensors
   but FALSE for real DTensors — DTensor.shape is the global,
   un-sharded shape. With FSDP=4 and HIDDEN=1024, this caused
   global_shape to be computed as 1024*4=4096 (wrong) and
   local_shard_range as rank*1024 instead of rank*256. Fix: detect
   torch.distributed.tensor.DTensor and use tensor.to_local().shape[dim]
   as local_extent. Plain-tensor and stand-in-DTensor paths are
   unchanged.

2. nemo_rl_v2.MxV2TrainingPublisher: shape_registry was only being
   transmitted via SourceIdentity.extra_parameters, which the running
   Rust server drops. Sidecar transport (the synthetic
   __mx_v2_meta__ TensorDescriptor) carries the v2 marker but not the
   registry. Receivers had no per-tensor placement info. Fix: include
   shape_registry in the sidecar JSON (nested-string ok, JSON encoder
   handles it). Receiver-side parsing already handles
   extra["shape_registry"].

Validated end-to-end on prime-rl-nixl-mx-trainer-0 (GB200, 4 GPUs):
4 ranks × 2 refit cycles, real torch.distributed mesh + Shard(0)
placement. Registry now reports correct global=(1024,2048) +
local_range=(rank*256, (rank+1)*256) per rank. All 8 transfers
all_elem_match=True (every byte of every received local shard equals
its rank+version sentinel).

Companion: scripts/v2_dtensor_e2e_demo.py exercises this codepath.
15/15 unit tests still pass (their stand-in fake-DTensor hits the
plain-tensor branch of describe_tensor).
KavinKrishnan and others added 6 commits May 22, 2026 06:41
Companion to v2_moe_e2e_demo.py that exercises the codepath that
DTensorPolicyWorker.stream_weights_via_mx uses in production: real
torch.distributed.tensor.DTensor on a (WORLD_SIZE,) mesh + Shard(0)
placement, MxV2TrainingPublisher.add_tensor reading .placements off
the DTensor, NIXL register on the local view, sidecar transport
round-trip, byte-level correctness on the same-rank pull.

Asserts:
  * registry's global_shape == DTensor.shape (un-sharded)
  * registry's local_shard_range == (rank*chunk, (rank+1)*chunk)
  * received local shard byte-matches every-cell sentinel
Documents the just-validated codepath: 4 ranks × 2 cycles, real
torch.distributed.tensor.DTensor on a (WORLD_SIZE,) mesh + Shard(0)
placement. All 8 transfers byte-correct (all_elem_match=True), and
the receiver's reconstructed shape registry reports correct
global=(1024,2048) + per-rank local_range=(rank*256, (rank+1)*256).

Also points out the two real bugs the test caught (DTensor.shape
semantics, sidecar shape_registry) and renumbers the existing §8.5+
sections accordingly. The §8.7 'NOT validated' entry on
DTensorPolicyWorker.stream_weights_via_mx is updated — the v2 protocol
mechanics are now exercised; only NemoRL-specific outer glue (HF
state_dict walk, MoE expert-name heuristic, cpu_offload lifecycle)
remains as integration-level work.
#295)

Signed-off-by: John Thomson <jothomson@nvidia.com>
… in loopback

Both v2 demo scripts (v2_moe_e2e_demo and v2_dtensor_e2e_demo) ran on
GB200 via UCX's intra-node cuda_ipc fast path, which silently tolerates
malformed prep_xfer_dlist entries. That's why neither demo caught the v2
sidecar (__mx_v2_meta__, addr=0, size=0) leak that PR #295 (commit
53c69ec) fixed — the bad descriptor only trips UCX's validator on real
cross-node rc_mlx5 / cuda_copy paths, which is what jthomson04 hit on
GB300 RoCE during Dynamo bring-up.

FORCE_RDMA=1 sets UCX_TLS=self,sm,rc_mlx5,cuda_copy,tcp (omitting
cuda_ipc), routing intra-node demos through the same strict validator
that cross-node would use. Pre-deploy runs should set this so future
descriptor-list bugs of this shape surface in loopback rather than
waiting for a real cross-host deployment.

Usage:

  FORCE_RDMA=1 WORLD_SIZE=4 python3 v2_moe_e2e_demo.py

Background: pensieve/RL/NemoRL/07_dynamo_handoff_2026_05_18.md
§"The real root cause: v2 sidecar leaking into RDMA descriptor list".
…source signature

After rebasing kavink/nemo_rl_moe onto current main, two API skew issues
surfaced that the rebase mechanical-merge couldn't see:

1. `nemo_rl_v2.py` imports `HeartbeatThread` from the flat
   `modelexpress.heartbeat` path, but main moved that module into the
   `modelexpress.metadata` subpackage. Crashed with
   `ModuleNotFoundError: No module named 'modelexpress.heartbeat'` on the
   trainer when DTensorPolicyWorker.stream_weights_via_mx first imports
   nemo_rl_v2. Fix: import from `.metadata.heartbeat`.

2. `MxRefitReceiver.receive_weights_scratch` passes
   `coalesce_transfers=False` to `NixlTransferManager.receive_from_source`.
   Main removed that parameter (coalescing is now gated entirely by
   `MX_POOL_REG`), so the call raises
   `TypeError: receive_from_source() got an unexpected keyword argument
   'coalesce_transfers'`. Worker logs showed the receiver looping with
   "[mx-poller] refit failed for version N; will retry" while the trainer
   silently proceeded on stale weights (the dynamo extension acks the
   refit RPC immediately and runs the actual receive in a background
   poller). Fix: drop the obsolete kwarg.

E2E validated on Qwen3-4B-Thinking + Dynamo vLLM v1 GRPO smoke, both
refit cycles complete with clean version transitions:

  step=1: RDMA transfer complete: 8.82 GB, 399 tensors, 0.20s, 357.5 Gbps
  step=2: RDMA transfer complete: 8.82 GB, 399 tensors, 0.18s, 384.4 Gbps
  [mx-poller] refit OK to version 1
  [mx-poller] refit OK to version 2

Signed-off-by: John Thomson <jothomson@nvidia.com>
…ulti-source slice planner

Implements Phases 3a, 3b, and 4 of the post-PR-#2389 RFC (KavinKrishnan/prime-rl:
kavink/post-2389-kernel-compile-plan, RFC §3.3 and §5):

Phase 3a — extend TensorDescriptorV2 with `compile_target` and
`compile_metadata`. Default `compile_target=hf_raw` keeps the wire byte-compat
with pre-Phase-3 candidates: encode_registry omits the field when it's the
default. Adds a `compile_target_matches` helper for receiver-side filtering
(whitelist + required-metadata-subset). New canonical string constants:
hf_raw, vllm_fused, deep_gemm_fp8, cutlass_fp8, trtllm.

Phase 3b — extend MxV2RefitReceiver.discover_v2_sources with
`compile_target_filter` and `required_compile_metadata`. Candidates without a
v2 registry are rejected when either is set (we can't certify bytes blindly).
Candidates with mixed compile targets are rejected if any tensor's target is
outside the allowed set. V2SourceCandidate now exposes `compile_targets:
frozenset[str]` for caller introspection.

Phase 4 — multi-source slice discovery for mixed trainer/inference TP:

- New types: TargetTPLayout, SliceSource, SliceCoveragePlan.
- New method MxV2RefitReceiver.discover_v2_sources_for_slice(target_layout=…):
  walks all v2 candidates per tensor, intersects their local_shard_range with
  the receiver's requested slice, emits a minimal SliceSource list. Detects
  coverage gaps + shard_axis mismatches and surfaces them in plan.missing.
- New method MxV2RefitReceiver.receive_via_plan(plan): orchestrates the
  scratch RDMA pulls per contributing candidate and stitches results with
  torch.cat along the shard axis. v0 issues one full scratch fetch per
  candidate; byte-level partial RDMA is a follow-up (RFC §5 Phase 4.5).

Unit tests (33 total, all green):
- Phase 3a: compile_target default, round-trip, wire-omission-when-default,
  matches helper with whitelist and required-metadata-subset (6 tests).
- Phase 3b: filter accepts matching / unset admits all / rejects when no
  registry / required-metadata pinning works (4 tests).
- Phase 4: planner covers within-shard slice, planner stitches cross-shard
  slice, planner picks freshest for REPLICATE, planner flags coverage gap,
  planner flags shard_axis mismatch (5 tests). receive_via_plan stitches
  two sources byte-exactly, passes through single source, refuses uncovered
  plan (3 tests).

Module __init__.py re-exports: TargetTPLayout, SliceCoveragePlan, SliceSource,
V2SourceCandidate, TensorDescriptorV2, compile_target_matches, all
COMPILE_TARGET_* constants.

Does not modify the existing single-source receive paths or the v1 fat
clients, so existing demos (v2_dtensor_e2e_demo.py, v2_moe_e2e_demo.py) and
downstream consumers (NemoRL kavink/mx_integration, jthomson04/RL Dynamo path)
continue to compile + run unchanged.
@copy-pr-bot
Copy link
Copy Markdown

copy-pr-bot Bot commented May 27, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

…_tensor

Phase-3 graduation glue. After this commit, callers of MxV2TrainingPublisher
can tag each tensor with its kernel layout at publish time, and the tag
flows end-to-end into the v2 sidecar TensorDescriptor's compile_target +
compile_metadata fields. Receivers can then filter via
discover_v2_sources(compile_target_filter=..., required_compile_metadata=...)
without needing to inspect the raw bytes.

Callers will typically pass these from their trainer-side conversion
registry. For prime-rl, that's
prime_rl.trainer.models.conversions.ConversionEntry.{compile_target,
compile_metadata} (Phase 3 trainer-side PR on
KavinKrishnan/prime-rl#kavink/post-2389-conversion-registry-extensions).

Defaults preserve back-compat: compile_target='hf_raw', compile_metadata={}.
Existing callers that don't pass either kwarg get the unchanged
behaviour and the descriptor's wire form is byte-identical to before
(encode_registry omits these fields when at defaults).

Tests: 35/35 green. Two new tests:
- test_phase3_add_tensor_threads_compile_target: 3-tensor mix
  (lm_head=hf_raw default, gate_proj=cutlass_fp8 with metadata,
   experts=deep_gemm_fp8 with block_size metadata) all flow correctly
  into the publisher's internal registry.
- test_phase3_add_tensor_compile_target_survives_encode_decode:
  end-to-end wire round-trip via encode_registry + decode_registry.

Doesn't change any existing call sites; this is the API extension that
makes the Phase 3 PR consumable.
…dapter for vLLM native RL APIs

Two additive surfaces on top of the existing Phase 2/3/4 work.

----------------------------------------------------------------------
1. Transfer-metrics instrumentation
----------------------------------------------------------------------

New TransferStats dataclass on the receiver side captures per-receive
metrics (bytes, tensors, elapsed, bandwidth_gbps, discovery_seconds,
path, training_step, source_worker_rank) in structured form. All three
MxRefitReceiver receive paths populate it:

  - receive_weights        → path="pre_registered"
  - receive_weights_scratch → path="scratch"
  - receive_weights_from_metadata → path="from_metadata"

Exposed on the receiver as `last_stats` (latest call) and `history`
(full per-call list). The v2 layer additionally tracks
`_last_discovery_seconds` (control-plane round-trip time, distinct from
data-plane RDMA time) so benchmarks can compare them.

The numbers were already produced inline as log lines; this commit
captures them as queryable state so benchmark harnesses don't have to
parse logs.

----------------------------------------------------------------------
2. MxWeightTransferEngine — vLLM native RL APIs adapter (Option A from
   pensieve/RL/PrimeRL/10_*.md)
----------------------------------------------------------------------

New module modelexpress/vllm_weight_transfer.py implements the
WeightTransferEngine abstract base introduced in the 2026-05-28 vLLM
Native RL APIs blog. RL frameworks pick it up via:

  import modelexpress.vllm_weight_transfer  # registers "mx_nixl"
  llm = LLM(..., weight_transfer_config=WeightTransferConfig(backend="mx_nixl"))

Three info dataclasses on the wire:

  MxInitInfo         — mx_server_url, model_name, worker_rank,
                       agent_name, device_id, publish_self_as_replica
  MxUpdateInfo       — version, target_tp_layout (Phase 4),
                       compile_target_filter (Phase 3b),
                       required_compile_metadata (Phase 3b),
                       same_rank_only + dedup_freshest_per_rank (Phase 2)
  MxTrainerSendArgs  — publisher, version, compile_target,
                       compile_metadata, per-tensor expert metadata

Dispatch logic in receive_weights:

  - target_tp_layout=None → matched-TP fast path: discover_v2_sources +
    pick_best_source + receive_from (single source, same-rank)
  - target_tp_layout=set → Phase-4 path: discover_v2_sources_for_slice
    + receive_via_plan (multi-source stitched)

Both paths apply the Phase 3 compile_target_filter +
required_compile_metadata at discovery time, refusing incompatible
sources BEFORE spending any RDMA cycles.

After a successful receive, optionally calls publish_self_as_source for
tree fan-out / pipeline replication (TensorHub pattern) — controlled by
MxInitInfo.publish_self_as_replica (default True). Failure of this
best-effort optimization does NOT propagate.

Trainer-side classmethod trainer_send_weights threads the compile_target
+ compile_metadata + per-tensor expert metadata from MxTrainerSendArgs
into each MxV2TrainingPublisher.add_tensor() call, then publishes once
with the version.

Registration with vLLM is via WeightTransferEngineFactory at import
time, with a try/except so the module is import-safe in environments
without vLLM (tests, publisher-only nodes, benchmark harnesses). The
MX_WEIGHT_TRANSFER_AUTOREGISTER=0 env var disables auto-registration
even when vLLM IS installed.

Metrics surface exposed for benchmarks:

  engine.last_transfer_stats     — TransferStats from most recent receive
  engine.transfer_history        — full per-call history
  engine.last_discovery_seconds  — most recent control-plane round-trip

----------------------------------------------------------------------
Tests
----------------------------------------------------------------------

14 new unit tests in test_vllm_weight_transfer.py (all green). Categories:

  - construction (with + without init_info, error before init)
  - matched-TP fast path (yielded tensors reach load_weights)
  - mixed-TP Phase 4 path (slice plan built + stitched)
  - uncovered-plan rejection (Phase 4 fully_covered=False raises)
  - no-source-matches-filter rejection (Phase 3b fast path)
  - kwarg passthrough (compile_target_filter, required_compile_metadata,
    same_rank_only, dedup_freshest_per_rank all reach the receiver)
  - publish_self_as_replica triggered post-receive
  - publish failure swallowed (tree fan-out is best-effort)
  - trainer_send_weights threads compile_target/compile_metadata/expert
    metadata to add_tensor + publishes with version
  - metrics surface returns sensible Nones / empties pre-init
  - factory contract: init_info_cls + update_info_cls class attrs

All 49 v2 tests still green (35 prior + 14 new). Companion design doc
at pensieve/RL/PrimeRL/10_mx_weight_transfer_engine_design.md.
…rk suite

Adds the three-scenario transport-layer benchmark we need to demonstrate
the MX v2 + MxWeightTransferEngine integration on real hardware:

  - elastic_scale  — N receivers join staggered; measures cold-start
                     join latency, per-cycle Gbps, control-plane vs
                     data-plane latency split
  - compile_target — three concurrent receivers (matched filter,
                     mismatched filter, no filter) prove the Phase 3b
                     safety net refuses incompatible bytes BEFORE RDMA
                     and that the no-filter path stays back-compatible
  - tree_fanout    — same as elastic but receivers also
                     publish_self_as_source; measures trainer egress
                     vs total delivered (fanout factor)

The harness drives both ends through the new MxWeightTransferEngine
adapter, so the numbers reflect what RL frameworks will see going
through vLLM's native WeightTransferEngine interface — not a special
private API.

Two run modes:

  - --mode=live (default) launches the trainer + receivers as
    subprocesses against a real MX server + NIXL data plane. Needs
    CUDA + NIXL + a reachable mx-server URL.
  - --mode=cpu  is a stubbed orchestrator-only smoke. Exercises
    result aggregation + summary table generation without touching
    MX or torch.cuda. Used in CI and for local development.

Output: human-readable summary table on stdout + machine-readable
JSON via --output (schema documented in benchmarks/README.md).

Companion changes:

  - MxV2RefitReceiver gains receive_from_scratch(candidate): wraps
    receive_weights_scratch so callers without pre-registered model
    buffers (the benchmark, cold-start vLLM workers) can still
    drive the matched-TP path
  - MxWeightTransferEngine switches the matched-TP path to use
    receive_from_scratch by default, matching the comment that was
    already in the source: scratch mode is the right cold-start
    behavior and matches Anyscale's RDT plugin pattern. When vLLM
    exposes register_destinations (proposed extension §5.1 in the
    design doc), this can switch to the zero-copy path.
  - The two engine tests that monkeypatched receive_from are updated
    to monkeypatch receive_from_scratch instead

Tests:

  - 9 new orchestrator unit tests cover bandwidth math, trainer-egress
    accounting (both with and without fan-out), compile_target verdict
    derivation, percentile helpers, summary-table conditional rendering,
    CLI arg parsing, and an end-to-end --mode=cpu CLI run that writes
    JSON to a tmp path
  - All 58 v2 unit tests still green (35 shape/picker + 14 engine + 9
    bench)

Companion doc at pensieve/RL/PrimeRL/11_benchmark_results.md with the
methodology + acceptance criteria for each scenario; numbers tables
left as PENDING until the cluster Teleport access is refreshed.
Lets us go from "Teleport auth refreshed" → "three JSON result files
in hand" with one command:

    ./run_cluster_bench.sh

The Job manifest (k8s/bench-elastic.yaml) runs all three scenarios in
sequence inside one pod in the kavin namespace. The pod uses the
existing prime-rl image (any image with modelexpress installed); the
harness lives at modelexpress/benchmarks/bench_elastic_scaling.py
inside the image, so no separate image build is needed.

The driver script (run_cluster_bench.sh):

  - cleans up prior Job runs
  - applies the manifest
  - waits for the pod to start (up to 2 min)
  - optionally tails logs (--watch)
  - waits for Job completion (30 min cap)
  - kubectl cp's /results/ out to ./results-<timestamp>/
  - prints a per-scenario summary

Resource request is 5 GPUs (1 trainer + 4 receivers for scenarios 1
and 3); change nvidia.com/gpu in the manifest if your namespace has
different quota. Tolerations + nodeSelector target the GB200 pool by
default; comment them out or replace for other environments.

When run, results files map directly to the tables in
pensieve/RL/PrimeRL/11_benchmark_results.md — paste the per-receiver
rows from the JSON files into the matching sections.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants