debug w4a4 trtllm rollout ipc #6455
DAPO/PPO with a QAT W4A4 (nvfp4) actor and TRT-LLM rollout failed at update_weights because the multiprocessing IPC path cannot move float8 tensors and TRT-LLM's restricted unpickler blocks the rebuild function. Four patches make the path work:

1. trtllm_rollout.py: register a float8 storage map and stub class so reduce_tensor can pickle float8 (fn / fnuz / e5m2). Before reduce_tensor, move CPU-resident params to CUDA (param_offload=True path) and view the float8 tensor as uint8 storage; return a 3-tuple (rebuild_fn, rebuild_args, orig_dtype_str) so the receiver can restore the original dtype.
2. trtllm_worker_extension.py: mirror the float8 storage map and stub class on the receiver. Monkey-patch tensorrt_llm.serialization.loads to extend its restricted-unpickler allowlist with torch.multiprocessing.reductions.rebuild_cuda_tensor. Unpack the 3-tuple and view back to the original float8 dtype.
3. trtllm_async_server.py: route non-VLM TRT-LLM rollouts through verl's WorkerExtension so the receiver patches above actually take effect for non-VLM models.
4. utils/qat/linear.py: minor QAT linear adjustment carried along.

Validated on Qwen3-8B-Base + DAPO + FSDP + TRT-LLM rollout + W4A4 (nvfp4), 2-node Lyris gb200, 9 RL training steps, 80 update_weights ENTER events across all WorkerDict ranks, timing_s/update_weights ~12.76s per step-1+ sync, no Tracebacks / OOM / NaN.

Known limitations / out of scope for this PR:

- The tensorrt_llm.serialization.loads monkey-patch is a workaround; a proper TRT-LLM API change would be cleaner.
- The non-VLM -> verl WorkerExtension change in trtllm_async_server.py affects all non-VLM TRT-LLM rollouts in this tree.
- W4A4 KL divergence is a separate documented issue.
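The allowlist extension described in patch 2 can be illustrated with a generic restricted unpickler built on the standard library. This is only a sketch of the general technique, not TRT-LLM's actual implementation: `ALLOWED_GLOBALS`, `RestrictedUnpickler`, and `restricted_loads` are hypothetical names, and `functools.partial` stands in for a rebuild function such as `rebuild_cuda_tensor`.

```python
import functools
import io
import pickle
from collections import OrderedDict

# Hypothetical base allowlist of (module, name) pairs the receiver trusts.
ALLOWED_GLOBALS = {("collections", "OrderedDict")}


class RestrictedUnpickler(pickle.Unpickler):
    """Unpickler that only resolves globals present in an allowlist."""

    def __init__(self, file, allowed):
        super().__init__(file)
        self._allowed = set(allowed)

    def find_class(self, module, name):
        # pickle calls find_class for every global it needs to import.
        if (module, name) not in self._allowed:
            raise pickle.UnpicklingError(f"global '{module}.{name}' is not allowed")
        return super().find_class(module, name)


def restricted_loads(data, extra_allowed=()):
    """Deserialize data, permitting the base allowlist plus any extra entries."""
    allowed = ALLOWED_GLOBALS | set(extra_allowed)
    return RestrictedUnpickler(io.BytesIO(data), allowed).load()


# An allowlisted type round-trips normally.
restored = restricted_loads(pickle.dumps(OrderedDict(weight=[1, 2, 3])))

# A rebuild-style callable that is not on the allowlist is rejected...
blocked_payload = pickle.dumps(functools.partial(int, "7"))
try:
    restricted_loads(blocked_payload)
    blocked = False
except pickle.UnpicklingError:
    blocked = True

# ...until the caller extends the allowlist with exactly the globals it needs.
rebuilt = restricted_loads(
    blocked_payload,
    extra_allowed={("functools", "partial"), ("builtins", "int")},
)
```

Extending the allowlist per call, rather than disabling the restriction, keeps the set of importable globals as small as the payload requires.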
Code Review
This pull request implements NVFP4 Quantization-Aware Training (QAT) support for the TRT-LLM rollout backend, featuring a workaround for IPC-sharing of float8 tensors by casting them to uint8 and monkey-patching serialization to permit required torch imports. It also introduces a force_dynamic_quantization configuration. Feedback highlights the need to remove extensive debug logging and a performance-degrading out-of-bounds detection block in the model forward pass that causes synchronous GPU-to-CPU stalls.
| import sys as _sys | ||
| try: | ||
| _iids = input_ids_rmpad | ||
| if _iids is not None and hasattr(_iids, "numel") and _iids.numel() > 0: | ||
| _mn = int(_iids.min().item()) | ||
| _mx = int(_iids.max().item()) | ||
| _vs = getattr(model.config, "vocab_size", None) if hasattr(model, "config") else None | ||
| print( | ||
| f"RAY_EXECUTOR_DEBUG: model_forward.input_ids_rmpad: " | ||
| f"shape={tuple(_iids.shape)} min={_mn} max={_mx} vocab_size={_vs}", | ||
| file=_sys.stderr, flush=True, | ||
| ) | ||
| if _vs is not None and _mx >= _vs: | ||
| print( | ||
| f"RAY_EXECUTOR_DEBUG: model_forward.OOB_DETECTED: " | ||
| f"max_token_id={_mx} >= vocab_size={_vs}", | ||
| file=_sys.stderr, flush=True, | ||
| ) | ||
| except Exception as _e: | ||
| print(f"RAY_EXECUTOR_DEBUG: model_forward.probe_error: {_e}", file=_sys.stderr, flush=True) |
This debug block introduces a severe performance regression in the model's forward pass. Calling .min().item() and .max().item() on GPU tensors (lines 288-289) triggers synchronous GPU-to-CPU data transfers, which stalls the computation pipeline. Furthermore, the use of print with flush=True in this hot path will cause significant overhead and log flooding. This entire block should be removed.
| print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps}", flush=True) | ||
| self.checkpoint_manager.update_weights(self.global_steps) | ||
| print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps}", flush=True) |
Leftover debug print statements found. These should be removed to maintain clean logs. Use the logging module if tracking weight updates is necessary.
| print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps}", flush=True) | |
| self.checkpoint_manager.update_weights(self.global_steps) | |
| print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps}", flush=True) | |
| self.checkpoint_manager.update_weights(self.global_steps) |
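If tracking weight updates is still wanted, the suggestion to use the logging module could look roughly like the sketch below. It assumes nothing about verl's own logging setup: the logger name, the `log_span` helper, and the list-collecting handler are all hypothetical and exist only to make the example self-contained.

```python
import logging
import time
from contextlib import contextmanager

# Hypothetical logger name; a real module would typically use __name__.
logger = logging.getLogger("verl.debug.update_weights")


@contextmanager
def log_span(name, **fields):
    """Log ENTER/EXIT for a code region at DEBUG level, with elapsed time."""
    detail = " ".join(f"{key}={value}" for key, value in fields.items())
    logger.debug("%s ENTER %s", name, detail)
    start = time.perf_counter()
    try:
        yield
    finally:
        elapsed = time.perf_counter() - start
        logger.debug("%s EXIT %s elapsed=%.3fs", name, detail, elapsed)


class ListHandler(logging.Handler):
    """Collects formatted messages so the example can inspect them."""

    def __init__(self):
        super().__init__()
        self.messages = []

    def emit(self, record):
        self.messages.append(record.getMessage())


handler = ListHandler()
logger.addHandler(handler)
logger.propagate = False

# At INFO level the span is silent, so normal training logs stay clean.
logger.setLevel(logging.INFO)
with log_span("ray_trainer.update_weights", step=3):
    pass  # stand-in for self.checkpoint_manager.update_weights(...)
quiet_count = len(handler.messages)

# At DEBUG level the same call site emits one ENTER and one EXIT record.
logger.setLevel(logging.DEBUG)
with log_span("ray_trainer.update_weights", step=3):
    pass
debug_messages = list(handler.messages)
```

Because the messages are emitted at DEBUG level, they can be switched on for a debug session without touching the call sites again.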
| print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps} (critic_warmup)", flush=True) | ||
| self.checkpoint_manager.update_weights(self.global_steps) | ||
| print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps} (critic_warmup)", flush=True) |
Leftover debug print statements found. These should be removed.
| print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps} (critic_warmup)", flush=True) | |
| self.checkpoint_manager.update_weights(self.global_steps) | |
| print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps} (critic_warmup)", flush=True) | |
| self.checkpoint_manager.update_weights(self.global_steps) |
| print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps}", flush=True) | ||
| self.checkpoint_manager.update_weights(self.global_steps) | ||
| print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps}", flush=True) |
Leftover debug print statements found. These should be removed.
| print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps}", flush=True) | |
| self.checkpoint_manager.update_weights(self.global_steps) | |
| print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps}", flush=True) | |
| self.checkpoint_manager.update_weights(self.global_steps) |
| print(f"RAY_EXECUTOR_DEBUG: transformer_impl.get_per_tensor_param _qat_enabled={self._qat_enabled} type={type(per_tensor_param).__name__}", flush=True) | ||
| if self._qat_enabled: | ||
| from verl.utils.modelopt import export_qat_weights | ||
| per_tensor_param = export_qat_weights(per_tensor_param, self.module, self._qat_config.mode, self.bridge) | ||
| print(f"RAY_EXECUTOR_DEBUG: transformer_impl.get_per_tensor_param post_qat type={type(per_tensor_param).__name__}", flush=True) | ||
Debug print statements in the weight export path should be removed to avoid log pollution during training.
| print(f"RAY_EXECUTOR_DEBUG: transformer_impl.get_per_tensor_param _qat_enabled={self._qat_enabled} type={type(per_tensor_param).__name__}", flush=True) | |
| if self._qat_enabled: | |
| from verl.utils.modelopt import export_qat_weights | |
| per_tensor_param = export_qat_weights(per_tensor_param, self.module, self._qat_config.mode, self.bridge) | |
| print(f"RAY_EXECUTOR_DEBUG: transformer_impl.get_per_tensor_param post_qat type={type(per_tensor_param).__name__}", flush=True) | |
| if self._qat_enabled: | |
| from verl.utils.modelopt import export_qat_weights | |
| per_tensor_param = export_qat_weights(per_tensor_param, self.module, self._qat_config.mode, self.bridge) |
| peft_config=None, so the rollout receives a standard weight update. | ||
| """ | ||
| print(f"RAY_EXECUTOR_DEBUG: engine_workers.update_weights ENTER global_steps={global_steps}", flush=True) |
| print(f"RAY_EXECUTOR_DEBUG: engine_workers.before_rollout_update_weights global_steps={global_steps}", flush=True) | ||
| await self.rollout.update_weights( | ||
| per_tensor_param, peft_config=peft_config, base_sync_done=True, global_steps=global_steps | ||
| ) | ||
| print(f"RAY_EXECUTOR_DEBUG: engine_workers.after_rollout_update_weights global_steps={global_steps}", flush=True) |
Pure observational prints across engine_workers, ray_trainer, model_forward, and engine/megatron/transformer_impl. Four functional W4A4 patches in 7c5660b are untouched.
Two status logs added during the debug session (NVFP4 QAT injection and force_dynamic_quantization flag activation). The functional config injections themselves are unchanged.
Factor the duplicated float8 storage-stub registration and NVFP4 quantization-config builder into _w4a4_compat. Inline a stray result-variable refactor and drop the _fdq tempvar. Revert the unrelated recipe submodule bump that slipped in with 7c5660b. No behavior change.