
debug w4a4 trtllm rollout ipc#6455

Draft
Superjomn wants to merge 5 commits into verl-project:main from Superjomn:chunweiy/w4a4-trtllm-rollout-ipc

@Superjomn
Collaborator

What does this PR do?

Add concise overview of what this PR aims to achieve or accomplish. Reference related GitHub issues and PRs that help with the review.

Checklist Before Starting

  • Search for similar PRs. Paste at least one query link here: ...
  • Format the PR title as [{modules}] {type}: {description} (This will be checked by the CI)
    • {modules} include fsdp, megatron, veomni, sglang, vllm, rollout, trainer, ci, training_utils, recipe, hardware, deployment, ray, worker, single_controller, misc, perf, model, algo, env, tool, ckpt, doc, data, cfg, reward, fully_async, one_step_off
    • If this PR involves multiple modules, separate them with , like [megatron, fsdp, doc]
    • {type} is in feat, fix, refactor, chore, test
    • If this PR breaks any API (CLI arguments, config, function signature, etc.), add [BREAKING] to the beginning of the title.
    • Example: [BREAKING][fsdp, megatron] feat: dynamic batching

Test

For changes that cannot be tested by CI (e.g., algorithm implementation, new model support), validate by experiment(s) and show results such as training curve plots and evaluation results.

API and Usage Example

Demonstrate how the API changes if any, and provide usage example(s) if possible.

# Add code snippet or script demonstrating how to use this

Design & Code Changes

Demonstrate the high-level design if this PR is complex, and list the specific changes.

Checklist Before Submitting

Important

Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

Superjomn added 2 commits May 16, 2026 18:16
DAPO/PPO with QAT W4A4 (nvfp4) actor and TRT-LLM rollout failed at
update_weights because the multiprocessing IPC path cannot move
float8 tensors and TRT-LLM's restricted unpickler blocks the
rebuild function. Four patches make the path work:

1. trtllm_rollout.py: register a float8 storage map and stub class
   so reduce_tensor can pickle float8 (fn / fnuz / e5m2). Before
   reduce_tensor, move CPU-resident params to CUDA (param_offload=True
   path) and view the float8 tensor as uint8 storage; return a
   3-tuple (rebuild_fn, rebuild_args, orig_dtype_str) so the
   receiver can restore the original dtype.

2. trtllm_worker_extension.py: mirror the float8 storage map and
   stub class on the receiver. Monkey-patch
   tensorrt_llm.serialization.loads to extend its
   restricted-unpickler allowlist with
   torch.multiprocessing.reductions.rebuild_cuda_tensor. Unpack the
   3-tuple and view back to the original float8 dtype.

3. trtllm_async_server.py: route non-VLM TRT-LLM rollouts through
   verl's WorkerExtension so the receiver patches above actually
   take effect for non-VLM models.

4. utils/qat/linear.py: minor QAT linear adjustment carried along.

Validated on Qwen3-8B-Base + DAPO + FSDP + TRT-LLM rollout + W4A4
(nvfp4), 2-node Lyris gb200, 9 RL training steps, 80
update_weights ENTER events across all WorkerDict ranks,
timing_s/update_weights ~12.76s per step-1+ sync, no Tracebacks /
OOM / NaN.

Known limitations / out of scope for this PR:
- The tensorrt_llm.serialization.loads monkey-patch is a
  workaround; a proper TRT-LLM API change would be cleaner.
- The non-VLM -> verl WorkerExtension change in
  trtllm_async_server.py affects all non-VLM TRT-LLM rollouts in
  this tree.
- W4A4 KL divergence is a separate documented issue.
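
To illustrate what "extending a restricted unpickler's allowlist" means in general, here is a minimal standard-library sketch. It does not use tensorrt_llm at all: `RestrictedUnpickler`, `ALLOWED`, and `restricted_loads` are hypothetical stand-ins, and `fractions.Fraction` plays the role of the blocked rebuild function.

```python
import io
import pickle
from fractions import Fraction

# Hypothetical allowlist of (module, name) pairs the unpickler may resolve.
ALLOWED = {("collections", "OrderedDict")}


class RestrictedUnpickler(pickle.Unpickler):
    def find_class(self, module, name):
        # Called for every global the pickle stream references.
        if (module, name) not in ALLOWED:
            raise pickle.UnpicklingError(f"{module}.{name} is not allowlisted")
        return super().find_class(module, name)


def restricted_loads(data: bytes):
    return RestrictedUnpickler(io.BytesIO(data)).load()


data = pickle.dumps(Fraction(3, 4))

# Before the allowlist is extended, the global lookup is rejected.
try:
    restricted_loads(data)
    blocked = False
except pickle.UnpicklingError:
    blocked = True

# Extending the allowlist lets the same payload through.
ALLOWED.add(("fractions", "Fraction"))
restored = restricted_loads(data)
```

The trade-off is the one the limitation above points at: every allowlist entry widens what a pickle payload can invoke, so an upstream API for registering trusted callables would be preferable to patching the loader from outside.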
@Superjomn Superjomn marked this pull request as draft May 24, 2026 12:49
@Superjomn Superjomn changed the title Chunweiy/w4a4 trtllm rollout ipc debug w4a4 trtllm rollout ipc May 24, 2026
@Superjomn Superjomn closed this May 24, 2026
Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request implements NVFP4 Quantization-Aware Training (QAT) support for the TRT-LLM rollout backend, featuring a workaround for IPC-sharing of float8 tensors by casting them to uint8 and monkey-patching serialization to permit required torch imports. It also introduces a force_dynamic_quantization configuration. Feedback highlights the need to remove extensive debug logging and a performance-degrading out-of-bounds detection block in the model forward pass that causes synchronous GPU-to-CPU stalls.

Comment thread verl/models/mcore/model_forward.py Outdated
Comment on lines +284 to +303
import sys as _sys
try:
    _iids = input_ids_rmpad
    if _iids is not None and hasattr(_iids, "numel") and _iids.numel() > 0:
        _mn = int(_iids.min().item())
        _mx = int(_iids.max().item())
        _vs = getattr(model.config, "vocab_size", None) if hasattr(model, "config") else None
        print(
            f"RAY_EXECUTOR_DEBUG: model_forward.input_ids_rmpad: "
            f"shape={tuple(_iids.shape)} min={_mn} max={_mx} vocab_size={_vs}",
            file=_sys.stderr, flush=True,
        )
        if _vs is not None and _mx >= _vs:
            print(
                f"RAY_EXECUTOR_DEBUG: model_forward.OOB_DETECTED: "
                f"max_token_id={_mx} >= vocab_size={_vs}",
                file=_sys.stderr, flush=True,
            )
except Exception as _e:
    print(f"RAY_EXECUTOR_DEBUG: model_forward.probe_error: {_e}", file=_sys.stderr, flush=True)

high

This debug block introduces a severe performance regression in the model's forward pass. Calling .min().item() and .max().item() on GPU tensors (lines 288-289) triggers synchronous GPU-to-CPU data transfers, which stalls the computation pipeline. Furthermore, the use of print with flush=True in this hot path will cause significant overhead and log flooding. This entire block should be removed.
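
The review asks for the block to be removed outright. If a range check like this is still wanted during a debugging session, one possible shape is sketched below, as an illustration only: `token_range` and its `enabled` flag are hypothetical names, not part of verl. The check is skipped entirely unless enabled, and when it does run it uses one fused reduction and one host transfer instead of two separate `.item()` calls.

```python
import logging

import torch

logger = logging.getLogger("token_range_probe")  # hypothetical logger name


def token_range(input_ids: torch.Tensor, vocab_size: int, enabled: bool = False):
    """Return (min_id, max_id, out_of_bounds), or None when the probe is off."""
    if not enabled or input_ids.numel() == 0:
        return None  # default path: no reduction and no host synchronization
    lo, hi = torch.aminmax(input_ids)  # min and max in a single reduction
    min_id, max_id = torch.stack((lo, hi)).tolist()  # one device-to-host copy
    out_of_bounds = min_id < 0 or max_id >= vocab_size
    if out_of_bounds:
        logger.warning(
            "token id out of range: min=%d max=%d vocab_size=%d",
            min_id, max_id, vocab_size,
        )
    return min_id, max_id, out_of_bounds


ids = torch.tensor([[1, 5], [7, 3]])
in_range = token_range(ids, vocab_size=8, enabled=True)
too_small_vocab = token_range(ids, vocab_size=7, enabled=True)
disabled = token_range(ids, vocab_size=7)
```

Even in this form the enabled path still synchronizes with the host once per call, so it is only suitable for short diagnostic runs.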

Comment thread verl/trainer/ppo/ray_trainer.py Outdated
Comment on lines +1303 to +1305
print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps}", flush=True)
self.checkpoint_manager.update_weights(self.global_steps)
print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps}", flush=True)

high

Leftover debug print statements found. These should be removed to maintain clean logs. Use the logging module if tracking weight updates is necessary.

Suggested change
- print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps}", flush=True)
- self.checkpoint_manager.update_weights(self.global_steps)
- print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps}", flush=True)
+ self.checkpoint_manager.update_weights(self.global_steps)
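
The reviewer's pointer to the logging module could be followed roughly as in the sketch below. It is an assumption-laden illustration: `log_span` and the `weight_sync` logger name are hypothetical, and the `pass` stands in for the real `update_weights` call.

```python
import logging
import time
from contextlib import contextmanager

logger = logging.getLogger("weight_sync")  # hypothetical logger name


@contextmanager
def log_span(name, **fields):
    """Emit DEBUG-level ENTER/EXIT records around a block of work."""
    suffix = " ".join(f"{key}={value}" for key, value in fields.items())
    logger.debug("%s ENTER %s", name, suffix)
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        logger.debug("%s EXIT %s elapsed=%.3fs", name, suffix, elapsed)


# Capture records in memory so the effect is visible without any log config.
records = []


class ListHandler(logging.Handler):
    def emit(self, record):
        records.append(record.getMessage())


logger.addHandler(ListHandler())
logger.setLevel(logging.DEBUG)

with log_span("update_weights", step=3):
    pass  # stand-in for self.checkpoint_manager.update_weights(self.global_steps)
```

Unlike `print(..., flush=True)`, these records are dropped cheaply when the logger's level is above DEBUG, so the tracing can stay in the code without flooding training logs.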

Comment thread verl/trainer/ppo/ray_trainer.py Outdated
Comment on lines +1559 to +1561
print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps} (critic_warmup)", flush=True)
self.checkpoint_manager.update_weights(self.global_steps)
print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps} (critic_warmup)", flush=True)

high

Leftover debug print statements found. These should be removed.

Suggested change
- print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps} (critic_warmup)", flush=True)
- self.checkpoint_manager.update_weights(self.global_steps)
- print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps} (critic_warmup)", flush=True)
+ self.checkpoint_manager.update_weights(self.global_steps)

Comment thread verl/trainer/ppo/ray_trainer.py Outdated
Comment on lines +1591 to +1593
print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps}", flush=True)
self.checkpoint_manager.update_weights(self.global_steps)
print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps}", flush=True)

high

Leftover debug print statements found. These should be removed.

Suggested change
- print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps}", flush=True)
- self.checkpoint_manager.update_weights(self.global_steps)
- print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps}", flush=True)
+ self.checkpoint_manager.update_weights(self.global_steps)

Comment on lines 733 to 739
print(f"RAY_EXECUTOR_DEBUG: transformer_impl.get_per_tensor_param _qat_enabled={self._qat_enabled} type={type(per_tensor_param).__name__}", flush=True)
if self._qat_enabled:
from verl.utils.modelopt import export_qat_weights

per_tensor_param = export_qat_weights(per_tensor_param, self.module, self._qat_config.mode, self.bridge)
print(f"RAY_EXECUTOR_DEBUG: transformer_impl.get_per_tensor_param post_qat type={type(per_tensor_param).__name__}", flush=True)


high

Debug print statements in the weight export path should be removed to avoid log pollution during training.

Suggested change
- print(f"RAY_EXECUTOR_DEBUG: transformer_impl.get_per_tensor_param _qat_enabled={self._qat_enabled} type={type(per_tensor_param).__name__}", flush=True)
- if self._qat_enabled:
- from verl.utils.modelopt import export_qat_weights
- per_tensor_param = export_qat_weights(per_tensor_param, self.module, self._qat_config.mode, self.bridge)
- print(f"RAY_EXECUTOR_DEBUG: transformer_impl.get_per_tensor_param post_qat type={type(per_tensor_param).__name__}", flush=True)
+ if self._qat_enabled:
+ from verl.utils.modelopt import export_qat_weights
+ per_tensor_param = export_qat_weights(per_tensor_param, self.module, self._qat_config.mode, self.bridge)

Comment thread verl/workers/engine_workers.py Outdated
peft_config=None, so the rollout receives a standard weight update.
"""

print(f"RAY_EXECUTOR_DEBUG: engine_workers.update_weights ENTER global_steps={global_steps}", flush=True)

high

Remove debug print statement.

Comment thread verl/workers/engine_workers.py Outdated
Comment on lines +710 to +714
print(f"RAY_EXECUTOR_DEBUG: engine_workers.before_rollout_update_weights global_steps={global_steps}", flush=True)
await self.rollout.update_weights(
per_tensor_param, peft_config=peft_config, base_sync_done=True, global_steps=global_steps
)
print(f"RAY_EXECUTOR_DEBUG: engine_workers.after_rollout_update_weights global_steps={global_steps}", flush=True)

high

Remove debug print statements.

        await self.rollout.update_weights(
            per_tensor_param, peft_config=peft_config, base_sync_done=True, global_steps=global_steps
        )

Superjomn added 2 commits May 24, 2026 06:24
Pure observational prints across engine_workers, ray_trainer,
model_forward, and engine/megatron/transformer_impl. Four
functional W4A4 patches in 7c5660b are untouched.

Two status logs added during the debug session (NVFP4 QAT injection
and force_dynamic_quantization flag activation). The functional
config injections themselves are unchanged.

@Superjomn Superjomn reopened this May 24, 2026

Factor the duplicated float8 storage-stub registration and NVFP4
quantization-config builder into _w4a4_compat. Inline a stray
result-variable refactor and drop the _fdq tempvar. Revert the
unrelated recipe submodule bump that slipped in with 7c5660b.

No behavior change.