Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
501ebcc
docs: add RL weight update integration slide deck
KavinKrishnan Apr 3, 2026
66a546a
docs: add markdown version of RL integration slide deck
KavinKrishnan Apr 4, 2026
9ce518d
feat: add MxTrainingPublisher and MxRefitReceiver for RL weight refit
KavinKrishnan Apr 7, 2026
c34003a
fix: list all sources and filter client-side by model_name (identity …
KavinKrishnan Apr 10, 2026
7e837c3
fix: register NIXL tensors only once per publisher lifetime
KavinKrishnan Apr 10, 2026
f7bcb16
feat: add receive_weights_scratch() for cross-format RDMA transfers
KavinKrishnan Apr 10, 2026
f978b6a
fix: disable transfer coalescing in receive_weights_scratch (incompat…
KavinKrishnan Apr 11, 2026
6b00b0f
fix: accept tensor_shapes in receive_weights_scratch for correct weig…
KavinKrishnan Apr 11, 2026
953b234
chore: remove review/feedback docs from kavink/RL branch
KavinKrishnan Apr 13, 2026
5d8ce71
docs: add MX RL integration overview (PRIME-RL + verl design)
KavinKrishnan Apr 14, 2026
5d46610
docs: add component architecture diagrams for PRIME-RL and verl POCs
KavinKrishnan Apr 14, 2026
b30bf8f
docs: add verl × ModelExpress integration overview with vertical diag…
KavinKrishnan Apr 22, 2026
8991660
docs(verl): remove related-documents section with internal-only paths
KavinKrishnan Apr 22, 2026
90e45ea
docs(RL): cross-reference draft overlay PR #2343 in the design doc
KavinKrishnan Apr 22, 2026
861bac2
docs(RL): refine VERL_MX_OVERVIEW for native nixl alignment and catal…
KavinKrishnan Apr 23, 2026
6ccec5d
docs(RL): add Path B native MX design as alternative to PI overlay
KavinKrishnan Apr 24, 2026
0e10aba
docs(RL): clarify v0.1 overlay scope in PRIMERL_MX diagrams
KavinKrishnan Apr 24, 2026
45341f3
docs(RL): backfill PRIMERL_MX_OVERVIEW with Scenario A GB200 results
KavinKrishnan Apr 24, 2026
b16d620
docs(RL): backfill PRIMERL_MX_OVERVIEW with B + C results
KavinKrishnan Apr 24, 2026
4c7e1df
fix(RL): address CodeRabbit review on PR #252
KavinKrishnan Apr 27, 2026
dc2856d
docs(RL): add NIXL compression study reproduction guide
KavinKrishnan Apr 29, 2026
5be0cef
docs(RL): clarify NIXL compression data package is request-only
KavinKrishnan Apr 29, 2026
848e7f7
docs(RL): publish NIXL compression study capture scripts
KavinKrishnan Apr 29, 2026
e8aefab
docs(RL): genericize NIXL compression study + add component diagram
KavinKrishnan Apr 29, 2026
16ce4fe
docs(RL): drop named NIXL inbox from compression study doc
KavinKrishnan Apr 29, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
389 changes: 389 additions & 0 deletions docs/MX_RL_OVERVIEW.md

Large diffs are not rendered by default.

303 changes: 303 additions & 0 deletions docs/RL/NIXL_COMPRESSION_STUDY.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,303 @@
# NIXL nvCOMP Compression Study — Reproducing with ModelExpress RL Workflows

**Last Updated**: April 29, 2026
**Audience**: NIXL compression team
**Purpose**: Guide the NIXL team to capture and study real RL weight-transfer payloads using our validated PRIME-RL and verl workflows with ModelExpress (MX).

---

## Background

The NIXL team is evaluating nvCOMP GPU compression on the tensors that flow through NIXL during RL post-training. There are two transfer types:

1. **RL refit** (training → inference): full model weights, every RL step.
2. **KV cache** (prefill → decode): per-request KV tensors in disaggregated inference.

We have **two validated end-to-end RL workflows** that produce these payloads over NIXL on GB200:

| Workflow | Framework | Status | PR | What it exercises |
|----------|-----------|--------|-----|-------------------|
| **PRIME-RL overlay** | PRIME-RL + vLLM | Scenarios A/B/C green on GB200 (20/20 steps each) | [PrimeIntellect-ai/prime-rl#2343](https://github.com/PrimeIntellect-ai/prime-rl/pull/2343) | NIXL RDMA weight push via PI's `NIXLWeightBroadcast` + `TransportPlan`, MX-mediated discovery |
| **verl MxCheckpointEngine** | verl + vLLM | 10 steps green on GB200 | [ai-dynamo/modelexpress#252](https://github.com/ai-dynamo/modelexpress/pull/252) | NIXL RDMA weight transfer via `MxCheckpointEngine` (`CheckpointEngine` plugin) |

Both produce the **exact same kind of data** the NIXL team requested: raw BF16 weight tensors flowing GPU-to-GPU over NIXL, plus pre/post RL-step weight deltas for delta-compression analysis.

---

## Component View

Where the data being studied actually lives + what writes/reads it. Green is what the compression team would compress; purple is the existing RL stack producing it.

```mermaid
flowchart TB
subgraph trainer_node["Trainer node — GB200"]
direction TB
T_FSDP["FSDP2 trainer<br/>optimizer.step()"]
T_PUB["MxTrainingPublisher<br/>+ NIXLWeightBroadcast"]
T_NIXL(["NIXL agent (UCX rc_mlx5)"])
T_FSDP --> T_PUB --> T_NIXL
end

subgraph mx_meta["Metadata plane — MX Server"]
MX["MX Server<br/>(gRPC)"]
REDIS[("Redis")]
MX --> REDIS
end

subgraph inference_node["Inference node — GB200"]
direction TB
I_NIXL(["NIXL agent (UCX rc_mlx5)"])
I_RECV["MxRefitReceiver<br/>+ NIXLWeightUpdateWorker"]
VLLM["vLLM engine<br/>(live params)"]
I_NIXL --> I_RECV --> VLLM
end

subgraph prefill_node["Prefill worker — GB200 (disagg inference)"]
direction TB
P_FWD["vLLM prefill pass"]
P_KV["KV cache buffer<br/>(per-layer key/value)"]
P_FWD --> P_KV
end

subgraph decode_node["Decode worker — GB200"]
direction TB
D_KV["KV cache import"]
D_GEN["Token generation"]
D_KV --> D_GEN
end

T_PUB -. "publish metadata<br/>(SourceIdentity, agent blob)" .-> MX
MX -. "discover" .-> I_RECV

T_NIXL ==> |"① RL REFIT<br/>weights (BF16)<br/>~3 GB / step (1.5B model)<br/>~140 GB / step (70B)"| I_NIXL
P_KV ==> |"② KV CACHE<br/>tensors (BF16)<br/>~14 MB at seq=512<br/>~3.5 GB at seq=131K"| D_KV

style T_FSDP fill:#533483,stroke:#e94560,color:#fff
style T_PUB fill:#533483,stroke:#e94560,color:#fff
style T_NIXL fill:#1b5e20,stroke:#4caf50,color:#fff
style I_NIXL fill:#1b5e20,stroke:#4caf50,color:#fff
style I_RECV fill:#533483,stroke:#e94560,color:#fff
style VLLM fill:#533483,stroke:#e94560,color:#fff
style P_FWD fill:#533483,stroke:#e94560,color:#fff
style P_KV fill:#1b5e20,stroke:#4caf50,color:#fff
style D_KV fill:#1b5e20,stroke:#4caf50,color:#fff
style D_GEN fill:#533483,stroke:#e94560,color:#fff
style MX fill:#533483,stroke:#e94560,color:#fff
style REDIS fill:#162447,stroke:#533483,color:#e0e0e0
```

**Compression target = the green edges.** ① is the RL-refit path between trainer and inference NIXL agents; ② is the KV cache transfer between prefill and decode. Everything purple — the trainer, the MX Server, vLLM, the receiver — is RL-stack infrastructure that wouldn't change if nvCOMP is added at the NIXL layer (compression would slot in transparently between `register` and `RDMA WRITE` on either edge).

The capture scripts in [`scripts/`](./scripts/) snapshot the bytes that cross those green edges, plus pre/post weight tensors for delta-compression analysis.

---

## Option 1: Request the pre-captured data package (fastest)

We have a ready-made data package captured from a live PRIME-RL deployment on GB200. **It's not in this repo** (binary tensors at GB scale aren't appropriate to commit) — request access from `kavink@nvidia.com` and we'll share via the appropriate channel (NV S3 bucket, internal share, or direct upload).

Package contents:

```text
RL_Qwen25/
├── model.safetensors # 2.9 GB — all 338 weight tensors (BF16)
├── weights_pre_rl.safetensors # 3.4 GB — weights before optimizer.step()
├── weights_post_rl.safetensors # 3.4 GB — weights after 1 AdamW step (lr=5e-6)
├── weight_deltas.safetensors # 3.4 GB — elementwise diff (post - pre), BF16
├── kv_cache/ # 14 MB — 56 KV tensors from a 501-token prefill
│ ├── layer_0_key.bin # shape [1, 2, 501, 128], BF16
│ ├── layer_0_value.bin
│ ├── ...
│ └── manifest.json # per-tensor metadata
├── manifest.json # 66 KB — per-weight-tensor metadata
└── README.md # full layout + compression properties
```

**Model**: Qwen2.5-1.5B BF16, 28 layers, 1.54B parameters. ~14 GB total package size.

**How to read**:

```python
from safetensors import safe_open
import torch

# Weights (the exact tensors NIXL transfers during RL refit)
with safe_open("model.safetensors", framework="pt") as f:
for key in f.keys():
tensor = f.get_tensor(key) # torch.bfloat16
raw_bytes = tensor.contiguous().untyped_storage() # raw bytes as on the wire
print(f"{key}: {tensor.shape}, {len(raw_bytes)} bytes")

# Weight delta (for delta-compression analysis — compute in FP32 for precision)
pre = safe_open("weights_pre_rl.safetensors", framework="pt")
post = safe_open("weights_post_rl.safetensors", framework="pt")
for key in pre.keys():
delta = post.get_tensor(key).float() - pre.get_tensor(key).float()
print(f"{key}: max_abs_delta={delta.abs().max():.2e}")

# KV cache (the exact tensors transferred prefill → decode via NIXL)
raw = open("kv_cache/layer_0_key.bin", "rb").read()
kv = torch.frombuffer(bytearray(raw), dtype=torch.bfloat16).reshape(1, 2, 501, 128)
```

**Key finding on deltas**: At BF16 precision, single-step RL deltas are mostly zero (AdamW updates at lr=5e-6 are below BF16's representable precision). For meaningful delta analysis, compute diffs in FP32. This suggests delta-compression should operate in FP32 and quantize back after.

---

## Option 2: Reproduce end-to-end on GB200 (PRIME-RL overlay)

Run our validated PRIME-RL overlay workflow and capture weights mid-flight using the published [`scripts/`](./scripts/) directory.

### Prerequisites

- A GB200 cluster (ARM64) with at least 2 nodes, container runtime, and an RDMA-capable interconnect (InfiniBand or RoCE) between nodes. Cluster orchestration is Kubernetes-based; manifests assume `kubectl` access and a working namespace.
- A namespace where you'll deploy the overlay, with the following bound:
- MX Server reachable at `modelexpress-server.<your-ns>.svc.cluster.local:8001` (Helm chart in this repo, or use an existing deployment)
- Redis backing the MX Server
- A shared model-cache PVC for HuggingFace downloads
- An image pull secret for the registry hosting the overlay image (we publish to `nvcr.io/nvidian/dynamo-dev/`; you can also build locally)

The included K8s manifests under `prime-rl/k8s/prime-rl-mx-on-nixl/` may need light edits (namespace, node selectors, image pull secret name, RDMA network annotations) for your cluster — they're examples, not portable across all GB200 deployments. The capture flow itself is cluster-agnostic.

### Step 1: Deploy the PRIME-RL overlay

```bash
git clone git@github.com:KavinKrishnan/prime-rl.git
cd prime-rl
git checkout kavink/mx-on-nixl

# Use the pre-built ARM64 image, or build locally
# Pre-built: nvcr.io/nvidian/dynamo-dev/prime-rl-mx-on-nixl:v0.2
docker buildx build --platform linux/arm64 \
-f docker/Dockerfile.mx-on-nixl \
-t <your-registry>/prime-rl-mx-on-nixl:v0.2 \
--push .

# Edit k8s/prime-rl-mx-on-nixl/*.yaml for your namespace, node selectors,
# RDMA network annotations, and image registry. Then:
cd k8s/prime-rl-mx-on-nixl
./run.sh deploy A # scenario A = PI's NIXL transport, no MX env vars
./run.sh status # wait until all 3 pods are Running
```

### Step 2: Verify the RL loop is running

```bash
NS=<your-namespace>
kubectl -n $NS logs prime-rl-mx-on-nixl-trainer-0 --tail=20 | grep "SUCCESS.*Step"
kubectl -n $NS logs prime-rl-mx-on-nixl-inference-0 | grep "update_weights.*200"
```

### Step 3: Capture using the published script

We ship `capture_on_pod.py` in [`scripts/`](./scripts/) — same script that produced our pre-captured Qwen2.5-1.5B package. It captures pre/post RL weights, simulates one AdamW step, computes deltas, and dumps a KV cache prefill, all in one pass.

```bash
NS=<your-namespace>

# Copy the script into the trainer pod
kubectl cp docs/RL/scripts/capture_on_pod.py \
$NS/prime-rl-mx-on-nixl-trainer-0:/tmp/capture.py

# Run it inside the pod (overlay image's interpreter is /app/.venv/bin/python)
kubectl exec $NS/prime-rl-mx-on-nixl-trainer-0 -- /app/.venv/bin/python /tmp/capture.py \
--model Qwen/Qwen2.5-1.5B \
--out /tmp/nixl_capture \
--kv-seq-len 512 \
--lr 5e-6

# Copy the results back
kubectl cp $NS/prime-rl-mx-on-nixl-trainer-0:/tmp/nixl_capture ./RL_capture
```

Output `RL_capture/` contains four sub-directories (`weights_pre_rl/`, `weights_post_rl/`, `weight_deltas/`, `kv_cache/`) each with raw `.bin` files plus a `manifest.json`. See [`scripts/README.md`](./scripts/README.md) for the full layout + flag reference.

### Step 4 (optional): Capture without a running RL deployment

If reproducing the overlay is more cluster work than the data is worth, [`scripts/capture_weights_and_kv.py`](./scripts/capture_weights_and_kv.py) is the **standalone** variant — works on any host (CPU or single GPU), no Kubernetes / RL framework required:

```bash
pip install torch transformers safetensors

python docs/RL/scripts/capture_weights_and_kv.py \
--model Qwen/Qwen2.5-1.5B \
--output-dir ./nixl_data \
--dtype bfloat16 \
--device cpu
```

Doesn't simulate an RL step (no pre/post/delta), but produces the same weight + KV cache layout the NIXL team can compress against.

### Step 5: Tear down

```bash
./run.sh clean
```

---

## Option 3: Reproduce with verl MxCheckpointEngine

The verl integration uses the same MX client but through verl's `CheckpointEngine` plugin. This path captures the weights as they flow through `MxCheckpointEngine.send_weights()` / `receive_weights()`.

Deployment docs: `docs/RL/VERL_MX_OVERVIEW.md` §6 in the modelexpress repo.

The capture approach is the same as Option 2 (exec into the trainer pod, save state dict pre/post step) since the weight tensors are identical — both frameworks produce `model.named_parameters()` in BF16. The difference is the transport path (verl's bucket+ZMQ metadata vs prime-rl's TransportPlan+slot system), which doesn't affect the tensor content.

---

## What to capture for the compression study

| Artifact | File | Size (Qwen3-0.6B) | What it represents |
|----------|------|-------|---------------------|
| **Current weights** | `weights_current.safetensors` | ~1.2 GB | Exact tensors registered with NIXL and RDMA-written to inference GPU every RL step |
| **Post-step weights** | `weights_post_step.safetensors` | ~1.2 GB | After one AdamW step (lr=5e-6) |
| **Weight deltas** | `weight_deltas.safetensors` | ~1.2 GB | `post - pre` in BF16 (mostly zero — compute in FP32 for real deltas) |
| **KV cache** | `kv_cache/*.bin` | ~14 MB | Prefill output transferred to decode workers via NIXL |
| **Manifest** | `manifest.json` | ~30 KB | Per-tensor: name, shape, dtype, size_bytes |

### Larger models for more representative data

The steps above use Qwen3-0.6B (our scenario A model). For larger models closer to production:

| Model | Params | Weight payload | Notes |
|-------|--------|----------------|-------|
| Qwen3-0.6B (above) | 0.6B | ~1.2 GB | Validated in PR #2343 scenarios A/B/C |
| Qwen2.5-1.5B | 1.5B | ~3 GB | Pre-captured package available on request (see Option 1) |
| Qwen2.5-7B | 7.6B | ~15 GB | T1 model in our overlay plan |
| Qwen3-MoE (PI offered spec) | MoE | varies | Would exercise `ExpertSlot` + per-expert tensors — most representative for MoE compression |

For models requiring multiple GPUs, the weights are FSDP-sharded — each rank's shard is `total / num_ranks` in size. The bytes on the wire per-rank are the shard size, not the full model.

---

## Compression-relevant properties

| Property | Weights | KV Cache | Delta (FP32) |
|----------|---------|----------|--------------|
| **Dtype on wire** | BF16 (2 B/elem) | BF16 (2 B/elem) | BF16 stored, but FP32 is the meaningful analysis dtype |
| **Value distribution** | Normal, centered ~0, std 0.01–0.1 | Wider, context-dependent | Very small magnitude (~1e-8 to 1e-6 per element) |
| **Sparsity** | Dense (no zeros) | Dense | ~100% zero at BF16 precision; structured-sparse at FP32 |
| **Best compression angle** | Entropy coding on mantissa bits | Temporal locality across layers | FP32 delta + entropy coding — high compressibility expected |
| **Transfer frequency** | Every RL step (~5–60 s) | Every request | Once for analysis |
| **Bucket size on wire** | 596 MB (measured in scenario A/B/C) | per-request, scales with seq_len | N/A |

### NIXL integration point for nvCOMP

If nvCOMP compression is added at the NIXL layer, the integration is transparent to both MX and the RL frameworks:

```text
Current:
Training GPU → NIXL register → RDMA WRITE (raw bytes) → Inference GPU

With NIXL-layer nvCOMP:
Training GPU → NIXL register → nvCOMP compress (GPU) → RDMA WRITE (compressed) → nvCOMP decompress (GPU) → Inference GPU
```

No changes to `MxTrainingPublisher`, `MxRefitReceiver`, `NIXLWeightBroadcast`, `TransportPlan`, or the MX Server protocol. Compression is internal to NIXL's transfer path. Our bucket-streaming pattern is preserved — compression happens per-bucket.

---

## Questions?

Reach out to Kavin Krishnan (`kavink@nvidia.com`) for access to the pre-captured data or help reproducing on a cluster. The PRIME-RL overlay branch (`KavinKrishnan/prime-rl:kavink/mx-on-nixl`) and the modelexpress RL branch (`ai-dynamo/modelexpress:kavink/RL`) are the entry points.
Loading
Loading