RL: MX weight-refit clients + integration design docs (PRIME-RL & verl) #252
KavinKrishnan wants to merge 25 commits into
Walkthrough
Introduces comprehensive RL weight-synchronization documentation for PRIME-RL and verl framework integrations via ModelExpress, alongside two new Python modules:
Changes
Estimated code review effort
🎯 3 (Moderate) | ⏱️ ~25 minutes
Poem
🚥 Pre-merge checks | ✅ 4 | ❌ 1
❌ Failed checks (1 warning)
✅ Passed checks (4 passed)
Actionable comments posted: 11
🧹 Nitpick comments (3)
modelexpress_client/python/modelexpress/refit_receiver.py (3)
283-290: Lift `_DTYPE_MAP` to module scope.
The dtype lookup table is rebuilt on every `receive_weights_scratch` call. Making it a module-level constant both makes intent clearer and lets other methods (e.g. a future `receive_weights` that needs to validate source dtype) share it.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelexpress_client/python/modelexpress/refit_receiver.py` around lines 283 - 290, The _DTYPE_MAP dictionary is currently created inside receive_weights_scratch causing it to be rebuilt on every call; move the _DTYPE_MAP definition to module scope (as a top-level constant) so it’s constructed once and can be reused by receive_weights_scratch and other functions (e.g., a future receive_weights) that need dtype validation; update references in receive_weights_scratch to use the module-level _DTYPE_MAP.
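A minimal sketch of the layout this nitpick suggests, not the code in the PR. The dtype keys and the `scratch_numel` helper are hypothetical, and byte widths stand in for `torch` dtypes so the snippet runs without PyTorch.

```python
from types import MappingProxyType

# Module-level constant: built once at import time and shared by every caller.
# Keys and byte widths are illustrative assumptions; the real table maps the
# publisher's dtype strings to torch dtypes.
_DTYPE_ELEM_SIZE = MappingProxyType({
    "bfloat16": 2,
    "float16": 2,
    "float32": 4,
    "int8": 1,
})


def scratch_numel(size_bytes: int, dtype_name: str) -> int:
    """Return how many elements a flat scratch buffer of size_bytes holds."""
    try:
        elem_size = _DTYPE_ELEM_SIZE[dtype_name]
    except KeyError:
        raise ValueError(f"unsupported dtype: {dtype_name!r}") from None
    if size_bytes % elem_size != 0:
        raise ValueError(
            f"{size_bytes} bytes is not a multiple of {elem_size}-byte {dtype_name}"
        )
    return size_bytes // elem_size
```

Unlike the `.get(td.dtype, torch.bfloat16)` fallback quoted in the review, this sketch rejects unknown dtypes. That is a design choice for illustration, not something the review asks for.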
230-232: Don't reach into `NixlTransferManager._tensors`.
Both `receive_weights` (L231) and `receive_weights_from_metadata` (L360) read `self._nixl._tensors`, a private attribute. If `NixlTransferManager` ever changes its internal storage (e.g. to a ref-counted dict or a two-tier cache), this silently breaks. Add a public accessor (`get_registered_tensor(name)` or `registered_tensors` property) on `NixlTransferManager` and call that instead.
Also applies to: 359-361
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelexpress_client/python/modelexpress/refit_receiver.py` around lines 230 - 232, The loops in receive_weights and receive_weights_from_metadata are directly accessing the private attribute NixlTransferManager._tensors; instead add a public accessor on NixlTransferManager (e.g., get_registered_tensor(name) or a registered_tensors property) that returns either a single tensor by name or an iterable/dict of registered tensors, then update receive_weights and receive_weights_from_metadata to call that public API (use get_registered_tensor(td.name) or check membership via registered_tensors) instead of reading _tensors directly so the implementation is decoupled from NixlTransferManager's internals.
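One way the suggested accessor could look, as a sketch. `TransferManagerSketch` is a hypothetical stand-in for `NixlTransferManager`; only the two accessor names come from the review, and `register_tensors` here has assumed update semantics.

```python
from collections.abc import Mapping
from types import MappingProxyType
from typing import Any


class TransferManagerSketch:
    """Hypothetical stand-in for NixlTransferManager; storage stays private."""

    def __init__(self) -> None:
        self._tensors: dict[str, Any] = {}

    def register_tensors(self, tensors: Mapping[str, Any]) -> None:
        # Assumed semantics: merge new registrations into the existing set.
        self._tensors.update(tensors)

    @property
    def registered_tensors(self) -> Mapping[str, Any]:
        """Read-only view, so callers cannot mutate internal storage."""
        return MappingProxyType(self._tensors)

    def get_registered_tensor(self, name: str) -> Any:
        """Look up one tensor by name with a descriptive error on a miss."""
        try:
            return self._tensors[name]
        except KeyError:
            raise KeyError(f"tensor {name!r} is not registered") from None
```

With an accessor like this, the receiver loops would test `td.name in manager.registered_tensors` or call `manager.get_registered_tensor(td.name)` and never touch `_tensors`.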
292-329: `scratch_shapes` is populated but never read; and `tensor_shapes` reshape may silently mismatch numel.
Two minor cleanups:
- `scratch_shapes` (lines 293, 301) is a dead variable.
- At L328, `tensor.view(tensor_shapes[name])` will raise a non-obvious error if the caller-supplied shape's product doesn't equal `td.size // elem_size`. Since the callers (vLLM/HF checkpoint pathways) are the ones passing `tensor_shapes` based on safetensors headers, consider asserting the numel match with a clear message before calling `.view`.
-        scratch_tensors: dict[str, torch.Tensor] = {}
-        scratch_shapes: dict[str, tuple[int, ...]] = {}
-        for td in source_tensors:
+        scratch_tensors: dict[str, torch.Tensor] = {}
+        for td in source_tensors:
             dt = _DTYPE_MAP.get(td.dtype, torch.bfloat16)
             elem_size = torch.tensor([], dtype=dt).element_size()
             numel = td.size // elem_size
             scratch_tensors[td.name] = torch.empty(
                 numel, dtype=dt, device=f"cuda:{self._device_id}"
             )
-            scratch_shapes[td.name] = (numel,)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@modelexpress_client/python/modelexpress/refit_receiver.py` around lines 292 - 329, Remove the unused scratch_shapes map (created near scratch_shapes: dict[...] and populated with scratch_shapes[td.name]) since it is never read; then, before reshaping in the loop that yields name, tensor, add an explicit assertion that the caller-provided tensor_shapes[name] product equals the allocated tensor.numel() (compute expected numel from source_tensors' td.size and elem_size or use tensor.numel()) and raise a clear error message referencing the tensor name and mismatched sizes if not equal, then call tensor.view(tensor_shapes[name]); keep usages of scratch_tensors, source_tensors, td, elem_size, and the RDMA call self._nixl.receive_from_source unchanged.
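The numel check described above could be sketched as follows. `check_view_shape` is a hypothetical helper and works on plain integers so it runs without PyTorch; in the real method the `numel` argument would come from `tensor.numel()`.

```python
import math
from collections.abc import Sequence


def check_view_shape(name: str, numel: int, shape: Sequence[int]) -> tuple[int, ...]:
    """Validate a caller-supplied shape against a flat buffer's element count.

    Returns the shape as a tuple so the caller can pass it straight to .view().
    """
    expected = math.prod(shape)
    if expected != numel:
        raise ValueError(
            f"shape mismatch for {name!r}: requested shape {tuple(shape)} has "
            f"{expected} elements, but the scratch buffer holds {numel}"
        )
    return tuple(shape)
```

A caller would then write something like `tensor.view(check_view_shape(name, tensor.numel(), tensor_shapes[name]))`, so a bad safetensors header fails with the tensor name in the message.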
🤖 Prompt for all review comments with AI agents
Verify each finding against the current code and only fix it if needed.
Inline comments:
In `@docs/MX_RL_OVERVIEW.md`:
- Around line 174-177: MxTrainingPublisher currently hardcodes
"training_framework": "prime_rl" inside _build_identity causing non-prime
integrations to publish wrong tags; modify MxTrainingPublisher so
training_framework is passed through initialize() (or constructor) and stored as
an instance attribute, update _build_identity to read self.training_framework
instead of the hardcoded string, and ensure both PRIME-RL and verl callers set
the appropriate framework value when calling initialize().
- Around line 27-28: Add a language identifier (e.g., text) to the fenced code
blocks that currently open with ``` so markdownlint MD040 is satisfied;
specifically update the fence before the block containing "Trainer GPU
MX Server (gRPC + Redis) Inference GPU" and the other two similar
fenced blocks in the document (the ones that mirror that header layout) to use
```text instead of ``` so they become ```text and preserve the inner content
unchanged.
- Around line 27-48: Replace the ASCII art sequence diagram with a mermaid
sequenceDiagram block that preserves the same participants and ordered steps:
Trainer GPU (call out optimizer.step()), MX Server (gRPC + Redis)
(publish_weights(), get_metadata()), and Inference GPU (poll_for_source(),
model.load_weights()); map steps 1–6 (optimizer.step, publish_weights,
poll_for_source, get_metadata, NIXL RDMA READ / GPU-to-GPU transfer,
model.load_weights) into mermaid messages and notes so rendering matches the
original flow and retains the NIXL RDMA READ/GPU-to-GPU data transfer comment;
ensure the fenced block uses ```mermaid to render in Markdown and replace the
entire ASCII diagram block.
In `@docs/RL/PRIMERL_MX_OVERVIEW.md`:
- Around line 218-241: The ASCII art timeline titled "DAG buildup over time" and
the "Before / After per-rank publishing" diagrams should be converted to mermaid
diagrams (use graph LR or sequenceDiagram as appropriate) so they render
consistently with the rest of the document; replace the ASCII blocks at the "DAG
buildup over time" and the per-rank publishing sections with mermaid fenced code
blocks (```mermaid ... ```), modeling nodes like Trainer, R0..R11 and the
publish/pull arrows and timeline steps t=0..t3; also add language identifier
`text` to the smaller state-schema and deployment-shape fenced blocks (the
state-schema and deployment-shape blocks noted in the comment) to satisfy
markdownlint MD040. Ensure labeling and arrow directions mirror the original
ASCII semantics and keep node names (Trainer, R0..R11, t=0..t3) unchanged.
In `@docs/RL/VERL_MX_OVERVIEW.md`:
- Around line 310-315: Remove the emoji markers in the deployment-mode table and
replace them with plain words; specifically update the status cells that
currently show ❌ and ✅ to textual statuses like "Not supported" and "Supported"
(or similar) for the rows referencing WorkerDict / execute_checkpoint_engine /
CheckpointEngineManager and Trainer / CheckpointEngineWorker, keeping the
explanatory sentence about this being a verl framework constraint and retaining
the references to the built-in engines `nixl` and `nccl` and the prototype note.
- Line 399: Remove the duplicated word "byte-exact byte-exact" in
docs/RL/VERL_MX_OVERVIEW.md and fix the sentence; additionally address the
"per-rank sharding-aware publishing" bullet by either moving that line to the
PRIME-RL overlay doc (docs/RL/PRIMERL_MX_OVERVIEW.md) if it describes PRIME-RL
behavior, or rephrase it in VERL_MX_OVERVIEW.md to clearly state how verl
implements or differs from "per-rank sharding-aware publishing" and whether the
§3.9-style guarantees apply to verl or only to the PRIME-RL overlay.
In `@docs/slides/mx-rl-integration-slides.html`:
- Around line 432-435: The slide contains an authoring placeholder text "[
INSERT DIAGRAM: diagram-architecture.svg ]" inside the <p> element within the
div.anim d3 block which ends up rendering to viewers; remove or replace that
placeholder by either embedding the actual SVG (replace the <p> placeholder with
the inline SVG markup or an <img src="diagram-architecture.svg">), or hide the
placeholder via CSS (target the .anim.d3 p fallback paragraph and set
display:none) so the fallback architecture diagram remains visible without
showing the authoring note.
In `@docs/slides/mx-rl-integration-slides.md`:
- Around line 29-63: Replace the ASCII art blocks (the bottleneck bar containing
"Rollout (40%) | Rew | Train (20%) | ██ REFIT (30%) ██" and the three-column
architecture block containing "Training Workers MX Server
Inference Workers") with mermaid diagrams (flowchart or sequence as appropriate)
that represent the same layout and include the component labels WeightExtractor,
MxTrainingPublisher, NIXL Agent, Metadata Coord, MxRefitReceiver, and NIXL
Agent; alternatively remove the ASCII fallback and embed the existing SVGs
(diagram-architecture.svg and diagram-rl-loop-bottleneck.svg) via standard
markdown image links, ensuring the slide uses mermaid or the SVGs per the repo
guideline.
In `@modelexpress_client/python/modelexpress/refit_receiver.py`:
- Around line 119-172: poll_for_source currently ignores min_step and returns
SourceRef with training_step=0 causing current_step to stay 0; update
poll_for_source (in refit_receiver.py) to filter instances by training_step >=
min_step and populate SourceRef.training_step from the publisher-advertised
value instead of hardcoding 0. Because the ListSourcesResponse.instances
(SourceInstanceRef) doesn't expose training_step, change the protobuf/api to
include training_step (or return SourceIdentity alongside each instance) so the
receiver can read extra_parameters["training_step"] (or the new training_step
field) and use it in the instance filter and when constructing SourceRef.
In `@modelexpress_client/python/modelexpress/training_publisher.py`:
- Around line 124-132: The hardcoded training_framework value in _build_identity
leaks into published metadata; update the class to accept a training_framework
in initialize (store it as self._training_framework) and replace the literal
"prime_rl" in _build_identity with that attribute so
SourceIdentity.extra_parameters uses self._training_framework; ensure
initialize's callers/constructors that create MxTrainingPublisher pass the
intended framework or default appropriately and that _build_identity continues
to return a p2p_pb2.SourceIdentity with extra_parameters including
"training_step" and the new "training_framework" value.
- Around line 204-256: publish_layer currently calls
self._nixl.register_tensors(...) every time but does not update/reset the
self._registered flag used by publish_weights, so if publish_layer and
publish_weights are interleaved NIXL can end up holding only a subset of
tensors; fix by enforcing mutual exclusivity or synchronizing registration
state: at the start of publish_layer (and similarly in publish_weights) assert
or raise if the other mode was previously used (check self._registered or a new
mode flag), or after calling self._nixl.register_tensors(...) update/reset
self._registered (or set a dedicated self._registration_mode flag) so
publish_weights will re-register when needed; reference publish_layer,
publish_weights, self._registered, register_tensors, and self._nixl when
implementing the check/reset.
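As an illustration of the mutual-exclusivity option for `publish_layer` and `publish_weights`, here is a small sketch. `PublishModeGuard` and its mode names are hypothetical; the real class tracks state through `self._registered`.

```python
from typing import Optional


class PublishModeGuard:
    """Hypothetical helper: remembers which publish path registered tensors."""

    _MODES = ("weights", "layer")

    def __init__(self) -> None:
        self._mode: Optional[str] = None

    def claim(self, mode: str) -> None:
        """Record the publish mode; refuse to switch without an explicit reset."""
        if mode not in self._MODES:
            raise ValueError(f"unknown publish mode: {mode!r}")
        if self._mode is not None and self._mode != mode:
            raise RuntimeError(
                f"publish mode {mode!r} requested after {self._mode!r}; "
                "call reset() and re-register tensors first"
            )
        self._mode = mode

    def reset(self) -> None:
        """Forget the current mode so the next publish call re-registers."""
        self._mode = None
```

Each publish method would call `claim()` before `register_tensors(...)`, so interleaving the two paths raises instead of leaving NIXL with a partial registration.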
🪄 Autofix (Beta)
Fix all unresolved CodeRabbit comments on this PR:
- Push a commit to this branch (recommended)
- Create a new PR with the fixes
ℹ️ Review info
⚙️ Run configuration
Configuration used: Path: .coderabbit.yaml
Review profile: CHILL
Plan: Pro
Run ID: 6893e61e-3a17-4a4f-a98a-bd66dc776e4d
⛔ Files ignored due to path filters (5)
- docs/slides/diagram-architecture.svg is excluded by !**/*.svg
- docs/slides/diagram-component-stack.svg is excluded by !**/*.svg
- docs/slides/diagram-framework-comparison.svg is excluded by !**/*.svg
- docs/slides/diagram-rl-loop-bottleneck.svg is excluded by !**/*.svg
- docs/slides/diagram-transfer-flow.svg is excluded by !**/*.svg
📒 Files selected for processing (8)
- docs/MX_RL_OVERVIEW.md
- docs/RL/PRIMERL_MX_OVERVIEW.md
- docs/RL/VERL_MX_OVERVIEW.md
- docs/slides/mx-rl-integration-slides.html
- docs/slides/mx-rl-integration-slides.md
- modelexpress_client/python/modelexpress/__init__.py
- modelexpress_client/python/modelexpress/refit_receiver.py
- modelexpress_client/python/modelexpress/training_publisher.py
| ``` | ||
| Trainer GPU MX Server (gRPC + Redis) Inference GPU |
Add language identifiers to fenced code blocks.
markdownlint (MD040) flags the opening fences at lines 27, 365, 376, and 388. Use text or an appropriate identifier so syntax highlighting and link checkers behave predictably.
Proposed minimal fix
-```
+```text
Trainer GPU MX Server (gRPC + Redis) Inference GPU
Apply the same at lines 365, 376, 388.
Also applies to: 365-366, 376-377, 388-389
🧰 Tools
🪛 markdownlint-cli2 (0.22.1)
[warning] 27-27: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/MX_RL_OVERVIEW.md` around lines 27 - 28, Add a language identifier
(e.g., text) to the fenced code blocks that currently open with ``` so
markdownlint MD040 is satisfied; specifically update the fence before the block
containing "Trainer GPU MX Server (gRPC + Redis)
Inference GPU" and the other two similar fenced blocks in the document (the ones
that mirror that header layout) to use ```text instead of ``` so they become
```text and preserve the inner content unchanged.
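For a document with many bare fences, the MD040 fix can be applied mechanically. The helper below is a hypothetical sketch, not part of the PR or of markdownlint; it only handles three-backtick fences and ignores `~~~` or longer fences.

```python
FENCE = "`" * 3  # three backticks, built this way to keep the example fence-safe


def add_fence_language(markdown: str, language: str = "text") -> str:
    """Add a language tag to opening code fences that lack one."""
    out = []
    in_fence = False
    for line in markdown.splitlines():
        stripped = line.strip()
        if stripped.startswith(FENCE):
            # Only a bare opening fence is rewritten; closing fences and
            # fences that already name a language are left alone.
            if not in_fence and stripped == FENCE:
                line = line.replace(FENCE, FENCE + language, 1)
            in_fence = not in_fence
        out.append(line)
    return "\n".join(out)
```

Running it over the file and diffing the result before committing is advisable, since it treats every fence line as a toggle and does not understand nested or indented code blocks.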
| ``` | ||
| Trainer GPU MX Server (gRPC + Redis) Inference GPU | ||
| │ │ │ | ||
| │ 1. optimizer.step() │ │ | ||
| │ (weights updated in VRAM) │ │ | ||
| │ │ │ | ||
| │ 2. publish_weights() │ │ | ||
| │──── tensor addrs + NIXL ──────►│ │ | ||
| │ metadata via gRPC │ │ | ||
| │ │ 3. poll_for_source() │ | ||
| │ │◄──── "any new weights?" ───────────│ | ||
| │ │ │ | ||
| │ │ 4. get_metadata() │ | ||
| │ │──── addrs + NIXL conn info ───────►│ | ||
| │ │ │ | ||
| │ 5. NIXL RDMA READ │ │ | ||
| │◄══════════════ GPU-to-GPU data transfer ═══════════════════════════►│ | ||
| │ (inference GPU reads from trainer GPU, CPU not involved) │ | ||
| │ │ │ | ||
| │ │ 6. model.load_weights() │ | ||
| │ │ (inference applies weights) │ | ||
| ``` |
Replace the ASCII sequence diagram with mermaid.
The architecture flow here is ASCII art, but mermaid is used elsewhere in this same file (the PRIME-RL and verl sections render cleanly). A mermaid sequenceDiagram gives the same information with better rendering across GitHub/VS Code and respects the repo convention.
As per coding guidelines: "Use mermaid diagrams instead of ASCII art in markdown files."
🧰 Tools
🪛 markdownlint-cli2 (0.22.1)
[warning] 27-27: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/MX_RL_OVERVIEW.md` around lines 27 - 48, Replace the ASCII art sequence
diagram with a mermaid sequenceDiagram block that preserves the same
participants and ordered steps: Trainer GPU (call out optimizer.step()), MX
Server (gRPC + Redis) (publish_weights(), get_metadata()), and Inference GPU
(poll_for_source(), model.load_weights()); map steps 1–6 (optimizer.step,
publish_weights, poll_for_source, get_metadata, NIXL RDMA READ / GPU-to-GPU
transfer, model.load_weights) into mermaid messages and notes so rendering
matches the original flow and retains the NIXL RDMA READ/GPU-to-GPU data
transfer comment; ensure the fenced block uses ```mermaid to render in Markdown
and replace the entire ASCII diagram block.
| 1. Trainer runs `optimizer.step()`, gathers FSDP2 shards, calls `MxTrainingPublisher.publish_weights()` | ||
| 2. Orchestrator detects new weights (via filesystem marker), tells inference to update | ||
| 3. Inference calls `MxRefitReceiver.receive_weights_scratch()` — NIXL RDMA pulls weights from trainer GPU | ||
| 4. Scratch tensors reshaped using safetensors header shapes, fed through `model.load_weights()` |
"training framework" string is hardcoded to prime_rl in MxTrainingPublisher.
This overview doc advertises MxTrainingPublisher as framework-agnostic (used by both PRIME-RL and verl integrations, per §verl and §PRIME-RL below), but modelexpress_client/python/modelexpress/training_publisher.py line 129 hardcodes "training_framework": "prime_rl" in _build_identity. The verl integration will publish with the wrong tag.
Consider threading training_framework through initialize() so both backends set it correctly. (Root-cause comment also relevant to the verl overview doc.)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/MX_RL_OVERVIEW.md` around lines 174 - 177, MxTrainingPublisher currently
hardcodes "training_framework": "prime_rl" inside _build_identity causing
non-prime integrations to publish wrong tags; modify MxTrainingPublisher so
training_framework is passed through initialize() (or constructor) and stored as
an instance attribute, update _build_identity to read self.training_framework
instead of the hardcoded string, and ensure both PRIME-RL and verl callers set
the appropriate framework value when calling initialize().
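A sketch of threading the framework tag through `initialize()`. `PublisherSketch` and `build_extra_parameters` are hypothetical; the real `_build_identity` returns a `p2p_pb2.SourceIdentity`, which is replaced here by a plain dict so the snippet is self-contained.

```python
class PublisherSketch:
    """Hypothetical stand-in for MxTrainingPublisher's identity handling."""

    def __init__(self) -> None:
        self._model_name = ""
        self._training_framework = ""
        self._initialized = False

    def initialize(self, model_name: str, training_framework: str) -> None:
        # The caller (PRIME-RL or verl integration) names its own framework.
        if not training_framework:
            raise ValueError("training_framework must be a non-empty string")
        self._model_name = model_name
        self._training_framework = training_framework
        self._initialized = True

    def build_extra_parameters(self, training_step: int) -> dict[str, str]:
        """Return the metadata that would go into extra_parameters."""
        if not self._initialized:
            raise RuntimeError("Call initialize() before publishing")
        return {
            "training_step": str(training_step),
            "training_framework": self._training_framework,
        }
```

A verl caller would pass `training_framework="verl"` and a PRIME-RL caller `"prime_rl"`. Whether to keep `"prime_rl"` as a default for backward compatibility is left open here.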
| ``` | ||
| t=0 Trainer publishes version N. | ||
| Sources for version N: {Trainer}. | ||
| MX Server DAG: Trainer ──→ (R0..R11 all polling) | ||
|
|
||
| t=t0 Trainer → R0 RDMA completes first. | ||
| R0 calls publish_rollout_source(version=N). | ||
| Sources: {Trainer, R0}. | ||
| MX Server DAG: Trainer ──→ (R1..R11 polling) | ||
| │ | ||
| └─ R0 ──→ (next pollers can choose R0 or Trainer) | ||
|
|
||
| t=t1 R1 and R2 pull in parallel from {Trainer, R0} (server load-balances). | ||
| Both finalize; publish_rollout_source(). | ||
| Sources: {Trainer, R0, R1, R2}. | ||
| Effective outbound: 4 NICs serving R3..R11. | ||
|
|
||
| t=t2 R3..R6 finalize from {Trainer, R0, R1, R2}. | ||
| Sources: {Trainer, R0..R6}. | ||
| Effective outbound: 8 NICs serving R7..R11. | ||
|
|
||
| t=t3 R7..R11 finalize. | ||
| All 12 rollouts hold version N. | ||
| ``` |
Convert the ASCII flow diagrams to mermaid.
The "DAG buildup over time" timeline (L218-241) and the "Before / After" per-rank publishing diagrams (L366-383) are ASCII art. Mermaid graph or sequenceDiagram renders this content consistently and matches the rest of this document (§1 and §2 already use mermaid). The smaller state-schema blocks (L255-259, L409-412) and the deployment-shape block (L511-530) are tabular/structural and acceptable as fenced text — but they need a language identifier (text) per markdownlint MD040.
As per coding guidelines: "Use mermaid diagrams instead of ASCII art in markdown files."
Also applies to: 366-383
🧰 Tools
🪛 markdownlint-cli2 (0.22.1)
[warning] 218-218: Fenced code blocks should have a language specified
(MD040, fenced-code-language)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/RL/PRIMERL_MX_OVERVIEW.md` around lines 218 - 241, The ASCII art
timeline titled "DAG buildup over time" and the "Before / After per-rank
publishing" diagrams should be converted to mermaid diagrams (use graph LR or
sequenceDiagram as appropriate) so they render consistently with the rest of the
document; replace the ASCII blocks at the "DAG buildup over time" and the
per-rank publishing sections with mermaid fenced code blocks (```mermaid ...
```), modeling nodes like Trainer, R0..R11 and the publish/pull arrows and
timeline steps t=0..t3; also add language identifier `text` to the smaller
state-schema and deployment-shape fenced blocks (the state-schema and
deployment-shape blocks noted in the comment) to satisfy markdownlint MD040.
Ensure labeling and arrow directions mirror the original ASCII semantics and
keep node names (Trainer, R0..R11, t=0..t3) unchanged.
| | Mode | Ray actors | Status for MX | | ||
| |------|-----------|--------------| | ||
| | **Hybrid (colocated)** | `WorkerDict` does both training and rollout | ❌ No `execute_checkpoint_engine` method — `CheckpointEngineManager` fails | | ||
| | **Standalone (disaggregated)** | Trainer uses `WorkerDict`, rollout uses `CheckpointEngineWorker` | ✅ Full CE lifecycle available | | ||
|
|
||
| This is a verl framework constraint, not an MX constraint — the built-in `nixl` and `nccl` engines have the same requirement. Our prototype runs in standalone mode on 2 nodes. |
Remove emoji markers from the deployment-mode table.
The red-cross and green-check on lines 312 and 313 are emojis inside a markdown file. As per coding guidelines: "No emojis in code or comments" (applies to **/*.md). Replace with plain words — they convey the same status without the rendering inconsistencies that emojis cause across editors and PDF exports.
Proposed fix
-| **Hybrid (colocated)** | `WorkerDict` does both training and rollout | ❌ No `execute_checkpoint_engine` method — `CheckpointEngineManager` fails |
-| **Standalone (disaggregated)** | Trainer uses `WorkerDict`, rollout uses `CheckpointEngineWorker` | ✅ Full CE lifecycle available |
+| **Hybrid (colocated)** | `WorkerDict` does both training and rollout | Not supported — no `execute_checkpoint_engine` method; `CheckpointEngineManager` fails |
+| **Standalone (disaggregated)** | Trainer uses `WorkerDict`, rollout uses `CheckpointEngineWorker` | Supported — full CE lifecycle available |
📝 Committable suggestion
| | Mode | Ray actors | Status for MX | | |
| |------|-----------|--------------| | |
| | **Hybrid (colocated)** | `WorkerDict` does both training and rollout | ❌ No `execute_checkpoint_engine` method — `CheckpointEngineManager` fails | | |
| | **Standalone (disaggregated)** | Trainer uses `WorkerDict`, rollout uses `CheckpointEngineWorker` | ✅ Full CE lifecycle available | | |
| This is a verl framework constraint, not an MX constraint — the built-in `nixl` and `nccl` engines have the same requirement. Our prototype runs in standalone mode on 2 nodes. | |
| | Mode | Ray actors | Status for MX | | |
| |------|-----------|--------------| | |
| | **Hybrid (colocated)** | `WorkerDict` does both training and rollout | Not supported — no `execute_checkpoint_engine` method; `CheckpointEngineManager` fails | | |
| | **Standalone (disaggregated)** | Trainer uses `WorkerDict`, rollout uses `CheckpointEngineWorker` | Supported — full CE lifecycle available | | |
| This is a verl framework constraint, not an MX constraint — the built-in `nixl` and `nccl` engines have the same requirement. Our prototype runs in standalone mode on 2 nodes. |
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/RL/VERL_MX_OVERVIEW.md` around lines 310 - 315, Remove the emoji markers
in the deployment-mode table and replace them with plain words; specifically
update the status cells that currently show ❌ and ✅ to textual statuses like
"Not supported" and "Supported" (or similar) for the rows referencing WorkerDict
/ execute_checkpoint_engine / CheckpointEngineManager and Trainer /
CheckpointEngineWorker, keeping the explanatory sentence about this being a verl
framework constraint and retaining the references to the built-in engines `nixl`
and `nccl` and the prototype note.
| <div class="anim d3" style="margin-top:16px;text-align:center;"> | ||
| <p style="font-size:13px;color:var(--text-dim);font-family:'Space Mono',monospace;letter-spacing:1px;margin-bottom:8px;"> | ||
| [ INSERT DIAGRAM: diagram-architecture.svg ] | ||
| </p> |
Unreplaced diagram placeholder is visible on the rendered slide.
The literal text [ INSERT DIAGRAM: diagram-architecture.svg ] will show up on Slide 3 because the inline "fallback" architecture diagram just below is rendered instead of (not as a replacement for) the missing SVG. Either embed the SVG, drop the placeholder text, or hide it with CSS so the audience doesn't see an authoring note.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/slides/mx-rl-integration-slides.html` around lines 432 - 435, The slide
contains an authoring placeholder text "[ INSERT DIAGRAM:
diagram-architecture.svg ]" inside the <p> element within the div.anim d3 block
which ends up rendering to viewers; remove or replace that placeholder by either
embedding the actual SVG (replace the <p> placeholder with the inline SVG markup
or an <img src="diagram-architecture.svg">), or hide the placeholder via CSS
(target the .anim.d3 p fallback paragraph and set display:none) so the fallback
architecture diagram remains visible without showing the authoring note.
| ``` | ||
| | Rollout (40%) | Rew | Train (20%) | ██ REFIT (30%) ██ | | ||
| ▲ BOTTLENECK ▲ | ||
| ``` | ||
|
|
||
| > Up to 30–40% of wall-clock for 70B+ models | ||
|
|
||
| ### Current refit latency (70B-class model, multi-node) | ||
|
|
||
| | Method | Latency | | ||
| |--------|---------| | ||
| | Filesystem (PRIME-RL) | ~20s+ | | ||
| | NCCL Broadcast (NeMo RL) | ~10s | | ||
| | ZMQ IPC (NeMo RL, co-located) | ~3-5s | | ||
| | **MX RDMA P2P (target)** | **~5s** | | ||
|
|
||
| --- | ||
|
|
||
| ## Slide 3 — The Solution: ModelExpress for Training→Inference Refit | ||
|
|
||
| Extend MX from inference-to-inference P2P to the training→inference boundary. Training workers register updated weights with NIXL, publish metadata to the MX Server, and RDMA-WRITE directly into inference GPU memory — bypassing CPU, disk, and collective overheads. | ||
|
|
||
| ### High-level data flow | ||
|
|
||
| ``` | ||
| Training Workers MX Server Inference Workers | ||
| (FSDP2 / Megatron) (gRPC + Redis/CRD) (vLLM / SGLang) | ||
|
|
||
| WeightExtractor ──gRPC──► Metadata Coord ◄──gRPC── MxRefitReceiver | ||
| MxTrainingPublisher Version Tracking NIXL Agent | ||
| NIXL Agent | ||
| │ ▲ | ||
| └══════════════ RDMA WRITE (GPU→GPU) ════════════════┘ | ||
| bypasses CPU & disk | ||
| ``` |
ASCII diagrams in slides markdown should be mermaid.
Lines 29-32 (the bottleneck bar) and 53-63 (the three-column architecture) are ASCII art. The repo already uses mermaid for equivalent diagrams (see docs/MX_RL_OVERVIEW.md and docs/RL/VERL_MX_OVERVIEW.md). Either convert these two blocks to mermaid or rely on the referenced SVGs (diagram-architecture.svg, diagram-rl-loop-bottleneck.svg) and drop the ASCII fallback.
As per coding guidelines: "Use mermaid diagrams instead of ASCII art in markdown files."
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@docs/slides/mx-rl-integration-slides.md` around lines 29 - 63, Replace the
ASCII art blocks (the bottleneck bar containing "Rollout (40%) | Rew | Train
(20%) | ██ REFIT (30%) ██" and the three-column architecture block containing
"Training Workers MX Server Inference Workers") with
mermaid diagrams (flowchart or sequence as appropriate) that represent the same
layout and include the component labels WeightExtractor, MxTrainingPublisher,
NIXL Agent, Metadata Coord, MxRefitReceiver, and NIXL Agent; alternatively
remove the ASCII fallback and embed the existing SVGs (diagram-architecture.svg
and diagram-rl-loop-bottleneck.svg) via standard markdown image links, ensuring
the slide uses mermaid or the SVGs per the repo guideline.
    def poll_for_source(
        self,
        model_name: str,
        min_step: int | None = None,
        status_filter: int = p2p_pb2.SOURCE_STATUS_READY,
        timeout_seconds: float = 0,
    ) -> SourceRef | None:
        """Check the MX Server for a training source with updated weights.

        Args:
            model_name: Model name to filter on (must match publisher's identity).
            min_step: If set, only return sources with ``training_step >= min_step``.
                Defaults to ``current_step + 1`` to only find newer versions.
            timeout_seconds: If > 0, poll repeatedly until a source is found
                or timeout is reached. If 0, check once and return immediately.

        Returns:
            A :class:`SourceRef` if a matching source was found, else *None*.
        """
        if not self._initialized:
            raise RuntimeError("Call initialize() before poll_for_source()")

        if min_step is None:
            min_step = self._current_step + 1

        deadline = time.perf_counter() + timeout_seconds

        while True:
            try:
                response = self._client.list_sources(
                    status_filter=status_filter,
                )
            except Exception as e:
                logger.warning(f"list_sources failed: {e}")
                if time.perf_counter() >= deadline:
                    return None
                time.sleep(0.5)
                continue

            for instance in response.instances:
                if instance.model_name != model_name:
                    continue

                return SourceRef(
                    mx_source_id=instance.mx_source_id,
                    worker_id=instance.worker_id,
                    model_name=instance.model_name,
                    worker_rank=instance.worker_rank,
                    training_step=0,
                )

            if time.perf_counter() >= deadline:
                return None
            time.sleep(0.5)
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
# Confirm the instance message shape and whether training_step is surfaced by ListSources.
fd -e proto | xargs rg -nP -C2 '\b(ListSources|SourceInstance|ListSourcesResponse|extra_parameters)\b'
Repository: ai-dynamo/modelexpress
Length of output: 1928
🏁 Script executed:
# Get the complete SourceInstanceRef message definition
rg -A 20 'message SourceInstanceRef' modelexpress_common/proto/p2p.proto
Repository: ai-dynamo/modelexpress
Length of output: 697
🏁 Script executed:
# Read the exact code being reviewed
sed -n '119,172p' modelexpress_client/python/modelexpress/refit_receiver.py | cat -n
Repository: ai-dynamo/modelexpress
Length of output: 2446
🏁 Script executed:
# Check what fields SourceRef expects and how training_step is used
rg -A 10 'class SourceRef' modelexpress_client/python/modelexpress/
Repository: ai-dynamo/modelexpress
Length of output: 1006
🏁 Script executed:
# Check GetMetadataResponse to see if training_step is there
rg -A 30 'message GetMetadataResponse' modelexpress_common/proto/p2p.proto
Repository: ai-dynamo/modelexpress
Length of output: 825
🏁 Script executed:
# Check WorkerMetadata to see if training_step is there
rg -A 20 'message WorkerMetadata' modelexpress_common/proto/p2p.proto
Repository: ai-dynamo/modelexpress
Length of output: 664
🏁 Script executed:
# Check SourceIdentity definition
rg -A 15 'message SourceIdentity' modelexpress_common/proto/p2p.proto
Repository: ai-dynamo/modelexpress
Length of output: 483
🏁 Script executed:
# Get more of SourceIdentity to find extra_parameters
rg -A 25 'message SourceIdentity' modelexpress_common/proto/p2p.proto
Repository: ai-dynamo/modelexpress
Length of output: 767
🏁 Script executed:
# Get the complete SourceIdentity message
rg -A 35 'message SourceIdentity' modelexpress_common/proto/p2p.proto
Repository: ai-dynamo/modelexpress
Length of output: 1074
🏁 Script executed:
# Get complete ListSourcesRequest and ListSourcesResponse
rg -A 15 'message ListSourcesRequest' modelexpress_common/proto/p2p.proto
Repository: ai-dynamo/modelexpress
Length of output: 620
🏁 Script executed:
# Check the receive_weights methods to see where training_step is set
rg -B 3 -A 3 'self._current_step = source.training_step' modelexpress_client/python/modelexpress/refit_receiver.py
Repository: ai-dynamo/modelexpress
Length of output: 498
🏁 Script executed:
# Check the publisher side to see how training_step is communicated
rg -B 5 -A 5 'training_step' modelexpress_client/python/modelexpress/training_publisher.py | head -50
Repository: ai-dynamo/modelexpress
Length of output: 818
🏁 Script executed:
# Check if there's any other place where training_step might be obtained
rg 'training_step' modelexpress_client/python/modelexpress/refit_receiver.py | head -20
Repository: ai-dynamo/modelexpress
Length of output: 498
`poll_for_source` hardcodes `training_step=0` and never filters on `min_step`, which breaks RL refit version tracking.
Two connected bugs:
1. Line 167: `training_step=0` is hardcoded in the returned `SourceRef`, independent of what the publisher advertised. Downstream, `receive_weights` / `receive_weights_scratch` / `receive_weights_from_metadata` set `self._current_step = source.training_step`, so `current_step` is permanently 0.
2. Lines 141-142 and 158-168: `min_step` is computed but never applied to filter the loop; the only predicate is `instance.model_name != model_name`. The docstring promise "only return sources with `training_step >= min_step`" is unimplemented. Combined with (1), a receiver can repeatedly pick up the same stale `READY` source.
However, the proto design issue is blocking: the publisher stores `training_step` in `SourceIdentity.extra_parameters["training_step"]`, but `ListSourcesResponse.instances` contains only `SourceInstanceRef` messages, which do not expose `extra_parameters` or any training step field.
To fix this: either add `training_step` (or `source_identity`) to `SourceInstanceRef` in the proto, or have `ListSources` return `SourceIdentity` alongside each instance so the receiver can extract and filter on the step.
🧰 Tools
🪛 Ruff (0.15.11)
[warning] 151-151: Do not catch blind exception: Exception
(BLE001)
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelexpress_client/python/modelexpress/refit_receiver.py` around lines 119 -
172, poll_for_source currently ignores min_step and returns SourceRef with
training_step=0 causing current_step to stay 0; update poll_for_source (in
refit_receiver.py) to filter instances by training_step >= min_step and populate
SourceRef.training_step from the publisher-advertised value instead of
hardcoding 0. Because the ListSourcesResponse.instances (SourceInstanceRef)
doesn't expose training_step, change the protobuf/api to include training_step
(or return SourceIdentity alongside each instance) so the receiver can read
extra_parameters["training_step"] (or the new training_step field) and use it in
the instance filter and when constructing SourceRef.
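The filtering the review asks for can be sketched as follows. This is a minimal illustration, not the project's code: `Candidate`, `parse_training_step`, and `select_source` are hypothetical names, and the sketch assumes the receiver already has each candidate's `extra_parameters` in hand (however they are fetched). It returns the first match, mirroring the original loop.

```python
from dataclasses import dataclass, field
from typing import Dict, Iterable, Optional, Tuple


@dataclass
class Candidate:
    """Hypothetical stand-in for one listed source plus its extra_parameters."""
    mx_source_id: str
    model_name: str
    extra_parameters: Dict[str, str] = field(default_factory=dict)


def parse_training_step(extra_parameters: Dict[str, str]) -> Optional[int]:
    """Return the advertised step, or None when it is missing or not an integer."""
    raw = extra_parameters.get("training_step")
    if raw is None:
        return None
    try:
        return int(raw)
    except ValueError:
        return None


def select_source(
    candidates: Iterable[Candidate], model_name: str, min_step: int
) -> Optional[Tuple[int, Candidate]]:
    """Return (step, candidate) for the first candidate that passes both filters."""
    for cand in candidates:
        if cand.model_name != model_name:
            continue
        step = parse_training_step(cand.extra_parameters)
        # Skip candidates whose step is unparseable or older than the threshold.
        if step is None or step < min_step:
            continue
        return step, cand
    return None
```

The returned step, rather than a constant, is what would populate `SourceRef.training_step`, so the receiver's current step advances after each refit.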
    def _build_identity(self, step: int) -> p2p_pb2.SourceIdentity:
        """Build a SourceIdentity proto with the current training step."""
        return p2p_pb2.SourceIdentity(
            extra_parameters={
                "training_step": str(step),
                "training_framework": "prime_rl",
            },
            **self._identity_kwargs,
        )
Hardcoded training_framework = "prime_rl" leaks into verl-published metadata.
The verl integration overview (docs/RL/VERL_MX_OVERVIEW.md) lists MxTrainingPublisher as the trainer-side primitive, and this hardcoded value will appear in SourceIdentity.extra_parameters for verl workloads too. Thread it through initialize():
Proposed fix
     def initialize(
         self,
         model_name: str,
         tensor_parallel_size: int = 1,
         pipeline_parallel_size: int = 1,
         expert_parallel_size: int = 1,
         dtype: str = "bfloat16",
+        training_framework: str = "unknown",
     ) -> None:
@@
         self._model_name = model_name
+        self._training_framework = training_framework
         self._identity_kwargs = dict(
@@
     def _build_identity(self, step: int) -> p2p_pb2.SourceIdentity:
         """Build a SourceIdentity proto with the current training step."""
         return p2p_pb2.SourceIdentity(
             extra_parameters={
                 "training_step": str(step),
-                "training_framework": "prime_rl",
+                "training_framework": self._training_framework,
             },
             **self._identity_kwargs,
         )
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelexpress_client/python/modelexpress/training_publisher.py` around lines
124 - 132, The hardcoded training_framework value in _build_identity leaks into
published metadata; update the class to accept a training_framework in
initialize (store it as self._training_framework) and replace the literal
"prime_rl" in _build_identity with that attribute so
SourceIdentity.extra_parameters uses self._training_framework; ensure
initialize's callers/constructors that create MxTrainingPublisher pass the
intended framework or default appropriately and that _build_identity continues
to return a p2p_pb2.SourceIdentity with extra_parameters including
"training_step" and the new "training_framework" value.
    def publish_layer(
        self,
        layer_state_dict: dict[str, torch.Tensor],
        layer_idx: int,
        step: int,
        worker_rank: int = 0,
    ) -> str:
        """Publish a single layer's weights to MX Server.

        Designed for PRIME-RL's layer-by-layer streaming pattern where
        ``filter_state_dict_by_layers()`` yields one layer at a time.

        Layer tensors are registered with NIXL (overwriting previous
        registration), and metadata is published to the MX Server. The
        inference side accumulates all layers before loading.

        Args:
            layer_state_dict: Parameter name -> tensor for this layer.
            layer_idx: Layer index (-1 for non-layer weights like embeddings).
            step: Current training step.
            worker_rank: GPU rank of this worker.

        Returns:
            The ``mx_source_id`` assigned by the server.
        """
        if not self._initialized:
            raise RuntimeError("Call initialize() before publish_layer()")

        self._nixl.register_tensors(layer_state_dict)
        metadata = self._nixl.nixl_metadata
        descriptors = self._nixl.tensor_descriptors

        identity = self._build_identity(step)
        identity.extra_parameters["layer_idx"] = str(layer_idx)

        worker_meta = p2p_pb2.WorkerMetadata(
            worker_rank=worker_rank,
            nixl_metadata=metadata,
            tensors=self._build_tensor_protos(descriptors),
            status=p2p_pb2.SOURCE_STATUS_INITIALIZING,
            agent_name=self._agent_name,
        )

        self._mx_source_id = self._client.publish_metadata(
            identity=identity,
            worker=worker_meta,
            worker_id=self._worker_id,
        )
        logger.debug(
            f"Published layer {layer_idx} ({len(layer_state_dict)} tensors) "
            f"for step {step}"
        )
        return self._mx_source_id
🧩 Analysis chain
🏁 Script executed:
#!/bin/bash
rg -nP --type=py -C3 '\b(publish_weights|publish_layer)\s*\('
Repository: ai-dynamo/modelexpress
Length of output: 4206
🏁 Script executed:
cat -n modelexpress_client/python/modelexpress/training_publisher.py | sed -n '148,203p'
Repository: ai-dynamo/modelexpress
Length of output: 2540
🏁 Script executed:
cat -n modelexpress_client/python/modelexpress/training_publisher.py | sed -n '1,100p'
Repository: ai-dynamo/modelexpress
Length of output: 4015
🏁 Script executed:
rg -nP --type=py '_registered' modelexpress_client/python/modelexpress/training_publisher.py
Repository: ai-dynamo/modelexpress
Length of output: 178
🏁 Script executed:
rg -nP --type=py '\._registered\s*=' modelexpress_client/python/modelexpress/training_publisher.py
Repository: ai-dynamo/modelexpress
Length of output: 141
🏁 Script executed:
cat -n modelexpress_client/python/modelexpress/training_publisher.py | sed -n '250,280p'
Repository: ai-dynamo/modelexpress
Length of output: 1410
`publish_layer` can silently invalidate NIXL registration when interleaved with `publish_weights`.
`publish_weights` registers tensors once and records that in `self._registered` (L174). `publish_layer` calls `register_tensors` on every invocation (L232) for a different tensor set. If `publish_layer` runs after `publish_weights` has already registered the full model, subsequent calls to `publish_weights` skip re-registration because `_registered` remains `True`, even though NIXL now holds only the last layer's buffers instead of the full model.
Although the docstring and example usage suggest these methods are mutually exclusive per run (one or the other, not both), there is no runtime assertion preventing both from being called. Consider adding a mutual-exclusivity check at the start of each method, or resetting `_registered` in `publish_layer`.
🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed.
In `@modelexpress_client/python/modelexpress/training_publisher.py` around lines
204 - 256, publish_layer currently calls self._nixl.register_tensors(...) every
time but does not update/reset the self._registered flag used by
publish_weights, so if publish_layer and publish_weights are interleaved NIXL
can end up holding only a subset of tensors; fix by enforcing mutual exclusivity
or synchronizing registration state: at the start of publish_layer (and
similarly in publish_weights) assert or raise if the other mode was previously
used (check self._registered or a new mode flag), or after calling
self._nixl.register_tensors(...) update/reset self._registered (or set a
dedicated self._registration_mode flag) so publish_weights will re-register when
needed; reference publish_layer, publish_weights, self._registered,
register_tensors, and self._nixl when implementing the check/reset.
Rebased onto current main (was 3 weeks stale; resolved one trivial
__all__ merge conflict in modelexpress/__init__.py).
Python — correctness fixes:
1. refit_receiver.poll_for_source: was hardcoding training_step=0
on the returned SourceRef and never filtering on min_step despite
advertising both in the docstring. ListSourcesResponse instances
carry only SourceInstanceRef (no extra_parameters), so the actual
training_step lives on SourceIdentity in the publisher's metadata.
Now do a per-candidate get_metadata() lookup, parse training_step
from SourceIdentity.extra_parameters, and skip candidates whose
step is below the threshold or unparseable. Cost: extra gRPC
round-trip per candidate; can be removed once training_step is
surfaced on SourceInstanceRef directly.
2. training_publisher.initialize(): training_framework was
hardcoded to "prime_rl" in _build_identity, which mislabeled
verl-published sources. Now a parameter on initialize() (default
"unknown" so callers know to set it explicitly).
3. training_publisher publish_weights / publish_layer mutual
exclusivity: publish_layer registers fresh tensors every call but
publish_weights caches via self._registered, so interleaving the
two paths could leave NIXL holding only the most-recently-
registered tensor set. New self._publish_mode tracks which path
is in use; either method raises if the other was already used on
this publisher.
4. refit_receiver._DTYPE_MAP: lifted to module scope (was rebuilt
per call inside receive_weights_scratch).
Docs — content fixes:
5. VERL_MX_OVERVIEW.md deployment-mode table: replaced ❌ / ✅
emoji markers with plain text per repo "no emojis in markdown"
guideline.
6. PRIMERL_MX_OVERVIEW.md §3.9: fixed duplicate "byte-exact
byte-exact" → "byte-exact".
7. MD040: annotated 15 bare ``` fences across MX_RL_OVERVIEW.md,
PRIMERL_MX_OVERVIEW.md, VERL_MX_OVERVIEW.md,
PRIMERL_MX_NATIVE_DESIGN.md, mx-rl-integration-slides.md as
```text where they were carrying plain prose / ASCII layout.
8. ASCII → mermaid:
- MX_RL_OVERVIEW.md §Architecture: ASCII trainer/server/inference
swimlane → sequenceDiagram.
- PRIMERL_MX_OVERVIEW.md §3.2 DAG buildup: 5-phase ASCII timeline
→ flowchart with one subgraph per phase, plus a per-phase
bandwidth table.
- PRIMERL_MX_OVERVIEW.md §3.9 before/after: naive-allgather vs
overlay per-rank flow → side-by-side flowchart.
- Slide-deck ASCII bottleneck-bar / 3-column architecture
intentionally retained: those are CSS-styled visual fallbacks
for the SVGs, not Markdown rendering targets. The misleading
"[ INSERT DIAGRAM: diagram-architecture.svg ]" placeholder text
above the architecture fallback was removed.
No proto / server-side changes. The poll_for_source fix is the
proto-level workaround documented in the CodeRabbit review; the
forward-looking fix (adding training_step directly to
SourceInstanceRef so we don't need the per-candidate get_metadata)
is a follow-up.
Made-with: Cursor
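Item 4 above (lifting the dtype table to module scope) follows a common pattern, sketched here without any framework dependency. The names `_DTYPE_ITEMSIZE` and `scratch_nbytes` are hypothetical, and mapping dtype names to element sizes is an assumption made to keep the example self-contained; the actual `_DTYPE_MAP` presumably maps names to framework dtype objects.

```python
import math
from typing import Tuple

# Built once at import time instead of on every call.
# Assumption: values are element sizes in bytes, keyed by dtype name.
_DTYPE_ITEMSIZE = {
    "float64": 8,
    "float32": 4,
    "float16": 2,
    "bfloat16": 2,
    "int64": 8,
    "int32": 4,
    "int8": 1,
    "uint8": 1,
}


def scratch_nbytes(dtype: str, shape: Tuple[int, ...]) -> int:
    """Bytes needed for a scratch buffer holding a tensor of this dtype and shape."""
    try:
        itemsize = _DTYPE_ITEMSIZE[dtype]
    except KeyError:
        raise ValueError(f"unsupported dtype: {dtype!r}") from None
    # math.prod(()) is 1, so a scalar (empty shape) needs exactly one element.
    return itemsize * math.prod(shape)
```

Because the table is a module constant, other functions that need to validate a source dtype can share the same lookup.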
d4d078c to
f0b7563
Compare
Presentation covering ModelExpress integration into RL post-training weight sync (refit) for NeMo RL, verl, and PRIME-RL. Includes HTML slideshow and standalone SVG diagrams for architecture, transfer flow, component stack, and framework comparison. Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Mirrors all 6 slides from the HTML presentation in plain markdown for easier viewing on GitHub and compatibility with Marp/Slidev. Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Training-side publisher registers updated model weights with NIXL and publishes metadata to the MX Server. Inference-side receiver discovers sources via ListSources, pulls weights via RDMA, and yields (name, tensor) pairs compatible with vLLM's load_weights(). Supports both all-at-once and layer-by-layer streaming patterns for PRIME-RL integration. Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
…hash mismatch) Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Tensor memory addresses don't change between optimizer steps, only values do. Calling register_memory every step accumulated descriptors, inflating the metadata blob from ~27 KB to 800+ KB and causing NIXL_ERR_NOT_ALLOWED on add_remote_agent. Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
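The register-once idea in the commit above can be sketched as a cache keyed by memory region. This is an illustration only: `RegistrationCache` and its `register_fn` callback are hypothetical stand-ins, and no real NIXL call is shown.

```python
from typing import Callable, Dict, List, Set, Tuple

Region = Tuple[int, int]  # (base address, size in bytes)


class RegistrationCache:
    """Hypothetical helper that registers each memory region at most once."""

    def __init__(self, register_fn: Callable[[Dict[str, Region]], None]) -> None:
        # register_fn stands in for whatever performs the actual registration.
        self._register_fn = register_fn
        self._seen: Set[Region] = set()

    def ensure_registered(self, regions: Dict[str, Region]) -> List[str]:
        """Register only regions not seen before; return their names, sorted."""
        fresh = {name: reg for name, reg in regions.items() if reg not in self._seen}
        if fresh:
            self._register_fn(fresh)
            self._seen.update(fresh.values())
        return sorted(fresh)
```

Calling `ensure_registered` on every optimizer step is then cheap: when addresses are unchanged nothing is re-registered, so descriptors do not accumulate from step to step.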
Allocates temporary GPU buffers matching the source's tensor layout, receives via NIXL RDMA, and yields (name, tensor) pairs in HF format. The caller's model.load_weights() handles name mapping and tensor fusion (e.g. HF q/k/v -> vLLM qkv_proj). Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
…ible with scratch buffers) Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
…ht reshaping Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
…rams Covers the MxCheckpointEngine design, Ray actor topology, and GB200 prototype results (10 steps, avg ~1.25s cross-node RDMA weight sync). Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
The section referenced recovery/ paths outside the ModelExpress repo that aren't accessible to external readers. Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
…og value
- Rename §2 to frame MX as additive on verl's NIXL checkpoint engine
- Document native nixl ring path positively; optional MX catalog + star
- Add catalog benefits (balancing, multi-source, publish/retire, retention)
- Fix RDMA READ source/destination wording; remove PRIME-RL references
- Tone: prefer native nixl vs consider mx; align metrics cross-refs
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Companion to PRIMERL_MX_OVERVIEW.md (Path A). Documents an MX-shaped weight broadcast design for prime-rl that uses PI's NIXL transport as the data plane but exposes ModelExpress's traditional API surface (model-agnostic, server-mediated, scratch-buffer-default, cross-framework-portable) instead of PI's per-model conversion_specs + slot system.
Key positioning:
- Path A = strict overlay on PI's API. Smallest diff. In flight.
- Path B = native MX shape. Larger diff. Staged design only.
- Both ship as discriminator options on weight_broadcast.type (existing nixl + new mx coexist).
Documents:
- Why we'd consider B (per-model spec wall, KL drift inheritance, elastic shapes, cross-framework alignment).
- File footprint (~600 LOC new, ~70 LOC modified).
- Migration path (toml-only for users).
- Pitch sequence for PI conversation.
- Inflection points to pivot from A to B.
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Fix two diagrams in PRIMERL_MX_OVERVIEW.md that implied MX replaces PI #2326's entire control plane, when v0.1 is env-var-gated and only swaps SPG coordinator discovery.
Component diagram:
- Label every imported PI element "(PI, unchanged)" and every MX element "(MX overlay, env-var gated)" so reader can see the split.
- Re-draw control-plane edges as dotted green (MX additions): publish_spg_coordinator (boot), mark_version_ready (per step), discover_spg_coordinator (boot), publish_rollout_source (optional).
- Add the SPG 2-round all_gather_obj edge between trainer and rollout labeled as PI code; previously missing entirely, so readers could think MX alone wired up NIXL agents.
- Data-plane edge labeled "trainer -> rollout recv buffer" to match PI's actual WRITE semantics instead of a generic bidirectional arrow.
Timing diagram:
- Wrap flow in three tinted bands: green boot-time MX discovery (the only v0.1 change) vs purple per-step SPG metadata rounds + RDMA WRITE (PI, unchanged).
- Move discovery out of the per-step par block into a "once per run" boot-time block: register_coordinator / discover_spg_coordinator are init-time, not update-time ops.
- Add the SPG 2-round all_gather_obj step that was missing: round 1 exchanges agent_meta + slot_layout + recv-buffer descriptors, round 2 exchanges per-slot xfer descriptors.
- Relabel the RDMA step as "NIXL RDMA WRITE -> rollout's recv buffer" so the write direction is explicit.
- Trigger for unpublish changed from vague "next iteration" to "async_level >= 1"; ties the mutability contract to the actual concurrency regime that needs it.
- Add explicit legend calling out MX-added vs PI-unchanged.
No design changes; docs-only clarification so the diagrams match what the overlay code actually does.
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Scenario A (PI NIXL direct refit on Qwen3-0.6B, 2 trainer ranks × 1 inference rank) completed all 20 RL training steps end-to-end on GB200 on April 24. Update the doc to reflect reality.
Changes:
- Status line: "Design complete ... Metrics below are TBD" → "Scenario A green end-to-end" with concrete numbers in the lead paragraph (596 MB/push, 310 slots, 100% success, draft PR link).
- §2 observed per-step timing table: replace PI-reported 12-node projections with measured scenario A numbers (20/20 steps, 5.1 s avg, 596 MB bucket, 310 slots published, rank-0 writes all 310 / rank-1 writes 197). B and C cells remain "pending" until the next session flips the env vars.
- §2 caveat added explaining that throughput comparison vs PI's prod numbers is distorted by our SDPA fallback (ARM64 image ships a flash_attn import stub; real kernels are a P1 follow-up). NIXL transfer itself is unaffected; step-time inflation is on the training-compute side.
- §4.5 metrics matrix: populate scenario A column with measured values; mark B/C/D/E as "pending" rather than "TBD" to signal the scenarios are scoped + instrumented, just not yet run.
- §4.6 results summary: rewrite from "to be written after the benchmark run" to a concrete list of what Scenario A proved (foundation validated, per-rank sharding-aware works as PI designed, overlay is structurally correct). Add pointers to the nine blocker fixes documented in OVERLAY_PR_EXECUTION_STATE.md. Keep the B/C/D/E expectations section framed as "next-session targets" so readers know what to expect the doc to grow into.
No diagram changes; April 24 timing + component diagram fixes from 461be85 remain the current shape.
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Path A overlay scenarios A, B, and C all completed 20/20 training
steps on GB200. Update §2, §4.5, §4.6 with the measured numbers and
update §1's status line accordingly. Diagrams unchanged.
§2 (timing diagram + observed timing table):
- Replace pending B/C cells with measured values.
- Wire BW per trainer NIC: 7.82-8.84 GB/s (avg ~8.1) — exceeds
PI's reported ~7.5 GB/s prod target.
- Aggregate net BW: 35-39 GB/s rank 0 + rank 1 combined.
- Per-push breakdown (scenario C): convert 60-67 ms + post+wait
15-16 ms + barrier 1.2-1.6 ms ≈ 80 ms total for 596 MB.
- Add the pipeline-replication catalog state output (4 sources
incl. rollout-source-0-*) as the empirical proof the MX-side
protocol works.
§4.5 metrics matrix:
- All A/B/C cells now have measured values; D and E flagged as
deferred (with reasons).
- 4.5.2 throughput row gains the wire/net BW measurements.
§4.6 results summary rewritten:
- A: foundation validated, nine blockers resolved.
- B: MX-mediated discovery validated (single source_id shared
across all participants), data-path parity with A. Documents
the metadata_endpoint→nixl_metadata workaround we landed
during this session.
- C: pipeline-replication catalog entry confirmed; per-push wire/
net BW measured. Honestly notes the bandwidth-amplification
benefit isn't shown end-to-end yet (gated on PI-side dynamic
SPG world_size for elastic mid-run join).
- D and E: explicitly deferred with rationale (D needs a drift
reproducer to be valuable; E gates on dynamic SPG).
§1 status line: "scenarios A, B, and C all green on GB200" replaces
the "scenario A green" wording.
Net for the PR-on-PR: A and B are the strongest evidence (overlay
is additive without regressing data path). C's catalog entry plus
the measured per-push BW round out the picture. D and E are honest
follow-up axes.
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
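The "≈ 80 ms" per-push figure for scenario C can be checked from the three phase ranges quoted above. The sketch below is plain arithmetic; the only assumption is that MB means 10**6 bytes. The effective rate it prints divides the whole 596 MB by the whole push time, which is not necessarily how the reported wire BW was measured.

```python
# Per-push phase timings for scenario C, in milliseconds (low, high).
CONVERT_MS = (60.0, 67.0)
POST_WAIT_MS = (15.0, 16.0)
BARRIER_MS = (1.2, 1.6)
PAYLOAD_MB = 596.0  # assumption: MB means 10**6 bytes


def total_ms(*phases):
    """Sum the low ends and the high ends of each (low, high) phase range."""
    low = sum(phase[0] for phase in phases)
    high = sum(phase[1] for phase in phases)
    return low, high


def effective_gb_per_s(payload_mb, elapsed_ms):
    """Payload over elapsed time in decimal GB/s (1 MB per ms equals 1 GB/s)."""
    return payload_mb / elapsed_ms


low, high = total_ms(CONVERT_MS, POST_WAIT_MS, BARRIER_MS)
midpoint = (low + high) / 2
print(f"per-push total: {low:.1f} to {high:.1f} ms (midpoint {midpoint:.1f} ms)")
print(f"effective rate at 80 ms: {effective_gb_per_s(PAYLOAD_MB, 80.0):.2f} GB/s")
```

The phase sums come to 76.2 ms and 84.6 ms, which bracket the quoted ≈ 80 ms total.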
f0b7563 to
4c7e1df
Compare
Adds NIXL_COMPRESSION_STUDY.md to help the NIXL nvCOMP compression team reproduce our RL weight-transfer payloads using our validated PRIME-RL and verl workflows.
Three paths documented:
1. Pre-captured data (fastest): pointer to our existing Qwen2.5-1.5B data package (model.safetensors + pre/post RL weights + deltas + KV cache, captured from live GB200 deployment).
2. End-to-end reproduction on GB200 via the PRIME-RL overlay (PR PrimeIntellect-ai/prime-rl#2343): deploy scenario A, exec into trainer pod, capture state_dict + simulate one RL step + dump KV cache. Step-by-step with kubectl commands.
3. Reproduction via verl MxCheckpointEngine (PR ai-dynamo/modelexpress #252): same tensor content, different transport path.
Also covers: compression-relevant properties table, per-tensor layout for Qwen3-0.6B and Qwen2.5-1.5B, delta analysis notes (BF16 deltas mostly zero at RL learning rates; FP32 diffs are the meaningful analysis target), NIXL integration point for nvCOMP (transparent: compress/decompress at the NIXL layer, no MX or framework changes), and a model-size scaling table for larger captures.
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
The pre-captured Qwen2.5-1.5B data package referenced in Option 1 of NIXL_COMPRESSION_STUDY.md isn't in this repo (binary tensors at GB scale aren't appropriate to commit) and the path I had previously shown was an internal local checkout. Replace with explicit "request from kavink@nvidia.com" framing and call out the appropriate channels (NV S3, internal share, or direct upload to eschmidt@nvidia.com per the original ask). Add the total package size (~14 GB) so the NIXL team knows what to expect bandwidth-wise. Update the "larger models" cross-reference accordingly. Signed-off-by: Kavin Krishnan <kavink@nvidia.com> Made-with: Cursor
Adds the two scripts that produced the Qwen2.5-1.5B data package we
referenced in NIXL_COMPRESSION_STUDY.md so the NIXL team can reproduce
captures themselves on different models / sequence lengths / clusters
without going through us as a manual relay.
New files:
docs/RL/scripts/capture_weights_and_kv.py
Standalone — any HF model, any host (CPU or single GPU), no
cluster / RL framework needed. CLI flags for model, dtype,
device, output dir, weights/KV-only modes, KV seq_len.
docs/RL/scripts/capture_on_pod.py
Inside-a-running-RL-pod variant. Generalized vs the original
Qwen2.5-1.5B-only capture: --model, --out, --kv-seq-len, --lr
flags. Captures pre/post RL weights + simulated AdamW step
delta + KV cache in one pass. Produces the four-directory
layout (weights_pre_rl/, weights_post_rl/, weight_deltas/,
kv_cache/) we shipped to the compression team.
docs/RL/scripts/README.md
Quick reference for both scripts: when to use each, complete
CLI examples, output layout, the BF16-deltas-are-mostly-zero
note + FP32 analysis snippet, pointer back to the main study
doc.
Updated:
docs/RL/NIXL_COMPRESSION_STUDY.md
Option 2 now points at scripts/capture_on_pod.py with kubectl
cp + exec invocation instead of an inlined heredoc Python
block. Added Option-2-Step-4 ("standalone capture without a
running RL deployment") pointing at capture_weights_and_kv.py
for users who don't want to deploy the full overlay.
The original on-disk capture scripts in our internal recovery
directory are unchanged; this just publishes a generalized,
flag-driven version of each into the public docs tree.
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
Two changes to NIXL_COMPRESSION_STUDY.md:
1. Add a component-view mermaid diagram at the top showing where the compression-target tensors actually live (RL refit edge between trainer and inference NIXL agents; KV cache edge between prefill and decode), with green nodes / edges marking the compression surface and purple marking RL-stack infrastructure that wouldn't change if nvCOMP slots into the NIXL layer transparently.
2. Drop GKE/cluster-specific assumptions. Previously Option 2 named a specific GKE node pool, namespace, registry, and tsh auth flow as prerequisites; now it just says "a GB200 cluster (ARM64) with at least 2 nodes, container runtime, RDMA-capable interconnect". The K8s manifests are flagged as examples that need light edits (ns, node selectors, registry, RDMA network annotations) per cluster. Hardcoded "kavin" namespace replaced with $NS=<your-namespace> throughout the kubectl commands so a copy-paste of the recipe works on any cluster.
The capture flow itself was already cluster-agnostic; these edits just stop the doc reading like it's only reproducible on our exact GKE shape.
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
Removes both references to eschmidt@nvidia.com from
NIXL_COMPRESSION_STUDY.md so the guide reads as a general
team-facing doc rather than addressed at one inbox. Audience line
now just says "NIXL compression team"; Option 1 channel list trims
"direct upload to your eschmidt@nvidia.com inbox per the original
request" down to "direct upload": same channel options, no
person-specific routing. Single contact for the data package
remains kavink@nvidia.com.
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
Rebased onto current main (was 3 weeks stale; resolved one trivial
__all__ merge conflict in modelexpress/__init__.py).
Python — correctness fixes:
1. refit_receiver.poll_for_source: was hardcoding training_step=0
on the returned SourceRef and never filtering on min_step despite
advertising both in the docstring. ListSourcesResponse instances
carry only SourceInstanceRef (no extra_parameters), so the actual
training_step lives on SourceIdentity in the publisher's metadata.
Now do a per-candidate get_metadata() lookup, parse training_step
from SourceIdentity.extra_parameters, and skip candidates whose
step is below the threshold or unparseable. Cost: extra gRPC
round-trip per candidate; can be removed once training_step is
surfaced on SourceInstanceRef directly.
2. training_publisher.initialize(): training_framework was
hardcoded to "prime_rl" in _build_identity, which mislabeled
verl-published sources. Now a parameter on initialize() (default
"unknown" so callers know to set it explicitly).
3. training_publisher publish_weights / publish_layer mutual
exclusivity: publish_layer registers fresh tensors every call but
publish_weights caches via self._registered, so interleaving the
two paths could leave NIXL holding only the most-recently-
registered tensor set. New self._publish_mode tracks which path
is in use; either method raises if the other was already used on
this publisher.
4. refit_receiver._DTYPE_MAP: lifted to module scope (was rebuilt
per call inside receive_weights_scratch).
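The mutual-exclusivity guard in item 3 can be sketched roughly as below. This is a minimal illustration, not the code in training_publisher: the class name `PublisherSketch`, the helper `_claim_mode`, the mode strings, and the use of `RuntimeError` are all assumptions; only `self._publish_mode`, `publish_weights`, and `publish_layer` are named in the commit message.

```python
class PublisherSketch:
    """Hypothetical stand-in for the training publisher; only the mode guard is modeled."""

    def __init__(self):
        # None until the first publish call picks a path.
        self._publish_mode = None

    def _claim_mode(self, mode):
        # Hypothetical helper: the first caller sets the mode,
        # and any later call on the other path raises.
        if self._publish_mode is None:
            self._publish_mode = mode
        elif self._publish_mode != mode:
            raise RuntimeError(
                f"publisher already used in {self._publish_mode!r} mode; "
                f"cannot switch to {mode!r}"
            )

    def publish_weights(self, state_dict):
        self._claim_mode("weights")
        # Placeholder for the cached-registration path.
        return sorted(state_dict)

    def publish_layer(self, name, tensor):
        self._claim_mode("layer")
        # Placeholder for the register-on-every-call path.
        return name
```

Calling the same method repeatedly stays legal; only switching paths on one publisher instance is rejected, which is what keeps NIXL from ending up with just the most recently registered tensor set.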
Docs — content fixes:
5. VERL_MX_OVERVIEW.md deployment-mode table: replaced ❌ / ✅
emoji markers with plain text per repo "no emojis in markdown"
guideline.
6. PRIMERL_MX_OVERVIEW.md §3.9: fixed duplicate "byte-exact
byte-exact" → "byte-exact".
7. MD040: annotated 15 bare ``` fences across MX_RL_OVERVIEW.md,
PRIMERL_MX_OVERVIEW.md, VERL_MX_OVERVIEW.md,
PRIMERL_MX_NATIVE_DESIGN.md, mx-rl-integration-slides.md as
```text where they were carrying plain prose / ASCII layout.
8. ASCII → mermaid:
- MX_RL_OVERVIEW.md §Architecture: ASCII trainer/server/inference
swimlane → sequenceDiagram.
- PRIMERL_MX_OVERVIEW.md §3.2 DAG buildup: 5-phase ASCII timeline
→ flowchart with one subgraph per phase, plus a per-phase
bandwidth table.
- PRIMERL_MX_OVERVIEW.md §3.9 before/after: naive-allgather vs
overlay per-rank flow → side-by-side flowchart.
- Slide-deck ASCII bottleneck-bar / 3-column architecture
intentionally retained: those are CSS-styled visual fallbacks
for the SVGs, not Markdown rendering targets. The misleading
"[ INSERT DIAGRAM: diagram-architecture.svg ]" placeholder text
above the architecture fallback was removed.
No proto / server-side changes. The poll_for_source fix is the
proto-level workaround documented in the CodeRabbit review; the
forward-looking fix (adding training_step directly to
SourceInstanceRef so we don't need the per-candidate get_metadata)
is a follow-up.
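As a rough sketch of the per-candidate workaround, assuming the metadata lookup can be modeled as a callable returning a dict of extra parameters: `filter_candidates`, `parse_training_step`, the `get_metadata` callable, and the `"training_step"` key name are hypothetical stand-ins, not the actual refit_receiver or gRPC client API.

```python
def parse_training_step(extra_parameters):
    """Return the step as an int, or None if it is missing or unparseable."""
    try:
        return int(extra_parameters["training_step"])
    except (KeyError, TypeError, ValueError):
        return None


def filter_candidates(candidates, get_metadata, min_step):
    """Keep (candidate, step) pairs whose training step is >= min_step.

    `get_metadata` is a hypothetical callable standing in for the
    per-candidate metadata lookup (one extra round-trip per candidate).
    """
    kept = []
    for cand in candidates:
        step = parse_training_step(get_metadata(cand))
        if step is None or step < min_step:
            # Skip stale or unparseable candidates instead of assuming step 0.
            continue
        kept.append((cand, step))
    return kept
```

If training_step were carried on the listed reference itself, the `get_metadata` call inside the loop would go away and the filter would run on the list response alone.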
Made-with: Cursor
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Adds NIXL_COMPRESSION_STUDY.md to help the NIXL nvCOMP compression
team reproduce our RL weight-transfer payloads using our validated
PRIME-RL and verl workflows.
Three paths documented:
1. Pre-captured data (fastest): pointer to our existing
Qwen2.5-1.5B data package (model.safetensors + pre/post RL weights
+ deltas + KV cache, captured from live GB200 deployment).
2. End-to-end reproduction on GB200 via the PRIME-RL overlay (PR
PrimeIntellect-ai/prime-rl#2343): deploy scenario A, exec into
trainer pod, capture state_dict + simulate one RL step + dump KV
cache. Step-by-step with kubectl commands.
3. Reproduction via verl MxCheckpointEngine (PR
ai-dynamo/modelexpress #252): same tensor content, different
transport path.
Also covers: compression-relevant properties table, per-tensor
layout for Qwen3-0.6B and Qwen2.5-1.5B, delta analysis notes (BF16
deltas mostly zero at RL learning rates; FP32 diffs are the
meaningful analysis target), NIXL integration point for nvCOMP
(transparent: compress/decompress at the NIXL layer, no MX or
framework changes), and a model-size scaling table for larger
captures.
Signed-off-by: Kavin Krishnan <kavink@nvidia.com>
Made-with: Cursor
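The "BF16 deltas mostly zero" note can be illustrated with a small standard-library sketch. It is not the analysis snippet from the study doc: the weight value 0.5 and update size 1e-6 are made-up illustrative numbers, and `to_bf16_bits` / `to_fp32` are hypothetical helpers (NaN handling is deliberately omitted).

```python
import struct


def to_fp32(x):
    """Round a Python float (float64) to the nearest float32 value."""
    return struct.unpack("<f", struct.pack("<f", x))[0]


def to_bf16_bits(x):
    """Round to bfloat16 (round-to-nearest-even) and return the 16-bit pattern.

    bfloat16 is the top 16 bits of the float32 encoding, so it keeps
    only 7 explicit mantissa bits. NaN inputs are not handled here.
    """
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


weight = 0.5          # illustrative weight value (assumption)
update = 1e-6         # illustrative per-step change (assumption)

# In BF16 the neighbouring representable values around 0.5 are about
# 0.004 apart, so a 1e-6 update rounds back to the same bit pattern.
bf16_changed = to_bf16_bits(weight + update) != to_bf16_bits(weight)

# In FP32 the spacing near 0.5 is about 6e-8, so the update survives.
fp32_delta = to_fp32(weight + update) - to_fp32(weight)
```

Here `bf16_changed` is False while `fp32_delta` is nonzero, which is why a BF16 pre/post diff looks mostly zero and the FP32 diff is the more informative compression target.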
Summary
docs/RL/VERL_MX_OVERVIEW.md: reframes §2 so verl’s native NIXLCheckpointEngine is described as the default GPU RDMA path; mx is positioned as an optional catalog + star layer on the same CheckpointEngine / bucket / NIXL READ foundation.
Test plan
Made with Cursor