From 8ce6b404a45333d73a4fa98b90684b2a5c4ea239 Mon Sep 17 00:00:00 2001 From: Superjomn Date: Sat, 16 May 2026 18:16:59 -0700 Subject: [PATCH 1/5] wip: 30b-qat-rotary-oob debug session edits (TRT-LLM rollout + megatron transformer) --- recipe | 2 +- verl/models/mcore/model_forward.py | 21 ++++++++++ verl/trainer/config/rollout/rollout.yaml | 4 ++ verl/trainer/ppo/ray_trainer.py | 6 +++ verl/workers/config/rollout.py | 2 + .../engine/megatron/transformer_impl.py | 2 + verl/workers/engine_workers.py | 3 ++ .../trtllm_rollout/trtllm_async_server.py | 42 +++++++++++++++++++ .../rollout/trtllm_rollout/trtllm_rollout.py | 23 +++++++++- .../trtllm_rollout/trtllm_worker_extension.py | 2 + 10 files changed, 105 insertions(+), 2 deletions(-) diff --git a/recipe b/recipe index ba246418f4d..e7f889574b8 160000 --- a/recipe +++ b/recipe @@ -1 +1 @@ -Subproject commit ba246418f4de12b845a09bba975f1a5242adc898 +Subproject commit e7f889574b8301cc0f0fc1d57c6d67f31ffeb689 diff --git a/verl/models/mcore/model_forward.py b/verl/models/mcore/model_forward.py index 885e1f69730..f74f32a65b3 100644 --- a/verl/models/mcore/model_forward.py +++ b/verl/models/mcore/model_forward.py @@ -281,6 +281,27 @@ def gptmodel_forward_model_engine( if vision_model: input_ids_rmpad, attention_mask = build_vlm_attn_mask_thd(input_ids, pad_token_id) + import sys as _sys + try: + _iids = input_ids_rmpad + if _iids is not None and hasattr(_iids, "numel") and _iids.numel() > 0: + _mn = int(_iids.min().item()) + _mx = int(_iids.max().item()) + _vs = getattr(model.config, "vocab_size", None) if hasattr(model, "config") else None + print( + f"RAY_EXECUTOR_DEBUG: model_forward.input_ids_rmpad: " + f"shape={tuple(_iids.shape)} min={_mn} max={_mx} vocab_size={_vs}", + file=_sys.stderr, flush=True, + ) + if _vs is not None and _mx >= _vs: + print( + f"RAY_EXECUTOR_DEBUG: model_forward.OOB_DETECTED: " + f"max_token_id={_mx} >= vocab_size={_vs}", + file=_sys.stderr, flush=True, + ) + except Exception as _e: + 
print(f"RAY_EXECUTOR_DEBUG: model_forward.probe_error: {_e}", file=_sys.stderr, flush=True) + output_orig = model( input_ids=input_ids_rmpad, attention_mask=attention_mask, diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml index a0f3eb92670..1536951be2b 100644 --- a/verl/trainer/config/rollout/rollout.yaml +++ b/verl/trainer/config/rollout/rollout.yaml @@ -446,3 +446,7 @@ mtp: ${oc.select:actor_rollout_ref.model.mtp, null} # QAT configuration (inherited from actor's engine config) qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}} + +# Experimental: force TRT-LLM to dynamically re-quantize incoming weights at reload time. +# Only consumed by the TRT-LLM rollout backend (see workers/rollout/trtllm_rollout/trtllm_async_server.py). +force_dynamic_quantization: False diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 08a4c436cee..846e09a3a58 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -1300,7 +1300,9 @@ def fit(self): # load checkpoint and update weights before doing anything self._load_checkpoint() + print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps}", flush=True) self.checkpoint_manager.update_weights(self.global_steps) + print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps}", flush=True) current_epoch = self.global_steps // len(self.train_dataloader) @@ -1554,7 +1556,9 @@ def fit(self): # implement critic warmup if self.config.trainer.critic_warmup > self.global_steps: # Still in critic warmup, only update weights to wake up rollout replicas. 
+ print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps} (critic_warmup)", flush=True) self.checkpoint_manager.update_weights(self.global_steps) + print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps} (critic_warmup)", flush=True) else: # update actor with marked_timer("update_actor", timing_raw, color="red"): @@ -1584,7 +1588,9 @@ def fit(self): # update weights from trainer to rollout with marked_timer("update_weights", timing_raw, color="red"): + print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps}", flush=True) self.checkpoint_manager.update_weights(self.global_steps) + print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps}", flush=True) actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index fa8b9573fa6..a3e2422e0e9 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -287,6 +287,8 @@ class RolloutConfig(BaseConfig): qat: Optional[dict] = None + force_dynamic_quantization: bool = False + def __post_init__(self): """Validate the rollout config""" # Deprecation warning for mode field - only async mode is supported diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py index cbb7be48141..1956c8c517b 100644 --- a/verl/workers/engine/megatron/transformer_impl.py +++ b/verl/workers/engine/megatron/transformer_impl.py @@ -730,10 +730,12 @@ def get_per_tensor_param(self, base_sync_done=False, **kwargs): ) # QAT: process weights through QATWeightExporter for quantized weight sync to vLLM + print(f"RAY_EXECUTOR_DEBUG: transformer_impl.get_per_tensor_param _qat_enabled={self._qat_enabled} type={type(per_tensor_param).__name__}", flush=True) if self._qat_enabled: from verl.utils.modelopt import export_qat_weights 
per_tensor_param = export_qat_weights(per_tensor_param, self.module, self._qat_config.mode, self.bridge) + print(f"RAY_EXECUTOR_DEBUG: transformer_impl.get_per_tensor_param post_qat type={type(per_tensor_param).__name__}", flush=True) return per_tensor_param, peft_config diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py index 74ecb711def..832a01fb0ad 100644 --- a/verl/workers/engine_workers.py +++ b/verl/workers/engine_workers.py @@ -673,6 +673,7 @@ async def update_weights(self, global_steps: int = None): peft_config=None, so the rollout receives a standard weight update. """ + print(f"RAY_EXECUTOR_DEBUG: engine_workers.update_weights ENTER global_steps={global_steps}", flush=True) # 0. send_weights only for async training with disaggregated trainer and rollout if self.config.rollout.checkpoint_engine.backend != "naive": per_tensor_param, _ = self.actor.engine.get_per_tensor_param() @@ -706,9 +707,11 @@ async def update_weights(self, global_steps: int = None): per_tensor_param_base, peft_config=peft_config, base_sync_done=False, global_steps=global_steps ) + print(f"RAY_EXECUTOR_DEBUG: engine_workers.before_rollout_update_weights global_steps={global_steps}", flush=True) await self.rollout.update_weights( per_tensor_param, peft_config=peft_config, base_sync_done=True, global_steps=global_steps ) + print(f"RAY_EXECUTOR_DEBUG: engine_workers.after_rollout_update_weights global_steps={global_steps}", flush=True) log_gpu_memory_usage("After update_weights", logger=logger) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 0a1c1147495..3da8baa467a 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -154,6 +154,25 @@ async def launch_server(self): else: raise ValueError(f"Currently only support fp8 quantization, got: {quantization}") + # NVFP4 QAT — actor 
(Megatron+modelopt) emits real-quantized NVFP4 weights; + # rollout must build NVFP4-packed linears to receive them. + qat_cfg = getattr(self.config, "qat", None) + if qat_cfg is not None and qat_cfg.get("enable", False): + import json as _json + _quant_json_path = qat_cfg.get("quantization_config_path") + with open(_quant_json_path) as _f: + _quant_json = _json.load(_f) + _gs = (_quant_json.get("config_groups", {}).get("group_0", {}) + .get("weights", {}) or {}).get("group_size", 16) + engine_kwargs.setdefault("model_kwargs", {})["quantization_config"] = { + "quant_method": "nvfp4", + "group_size": _gs, + "modules_to_not_convert": _quant_json.get("ignore", []) or ["lm_head"], + } + if self.config.load_format != "dummy": + raise ValueError("NVFP4 QAT rollout requires load_format=dummy") + logger.info(f"NVFP4 QAT injected: group_size={_gs}") + llm_kwargs = { "model": self.model_config.local_path, "backend": "pytorch", @@ -220,6 +239,13 @@ async def launch_server(self): } ) + # Experimental: force TRT-LLM to dynamically re-quantize incoming weights + # at reload time. Reads actor_rollout_ref.rollout.force_dynamic_quantization. 
+ _fdq = bool(getattr(self.config, "force_dynamic_quantization", False)) + if _fdq: + llm_kwargs["force_dynamic_quantization"] = True + logger.info("force_dynamic_quantization=True set on TRT-LLM rollout") + self.llm = await AsyncLLM(**llm_kwargs) import inspect @@ -285,6 +311,22 @@ async def generate( sampling_params=trt_llm_sampling_params, ) token_ids = outputs.outputs[0].token_ids + try: + import sys as _dbg_sys + _tids = list(token_ids) if token_ids is not None else [] + if _tids: + _mn = min(_tids); _mx = max(_tids) + _bad = _mx >= 200000 or _mn < 0 + print( + f"RAY_EXECUTOR_DEBUG: rollout.generate.token_ids: " + f"len={len(_tids)} min={_mn} max={_mx} bad={_bad} " + f"first5={_tids[:5]} last5={_tids[-5:]} " + f"finish_reason={outputs.outputs[0].finish_reason}", + file=_dbg_sys.stderr, flush=True, + ) + except Exception as _e: + import sys as _dbg_sys + print(f"RAY_EXECUTOR_DEBUG: rollout.generate.probe_error: {_e}", file=_dbg_sys.stderr, flush=True) log_probs = None if outputs.outputs[0].logprobs is not None: # When logprobs=1, TRT-LLM returns only the sampled token's logprob at each position diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py index 4fbf0e614f4..8ee387a79e0 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py @@ -262,7 +262,10 @@ async def update_weights(self, weights: dict[str, str]): Returns: Dict[str, Any]: Server response containing update status """ - return await self._make_async_request("update_weights", {"weights": weights}) + print(f"RAY_EXECUTOR_DEBUG: AsyncTRTLLMHttpAdapter.update_weights ENTER n_weights={len(weights) if weights else 0}", flush=True) + result = await self._make_async_request("update_weights", {"weights": weights}) + print(f"RAY_EXECUTOR_DEBUG: AsyncTRTLLMHttpAdapter.update_weights EXIT", flush=True) + return result class ServerAdapter(BaseRollout): @@ -300,6 +303,24 @@ def 
__init__( } fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS) model_config.hf_config.quantization_config = fp8_block_quant_kwargs + + # NVFP4 QAT — mirror actor's quant config onto hf_config so ServerAdapter + # weight-sync sees NVFP4 metadata (same role as the FP8 block above). + qat_cfg = config.get("qat", None) + if qat_cfg is not None and qat_cfg.get("enable", False): + qat_path = qat_cfg.get("quantization_config_path") + if qat_path: + import json as _json + with open(qat_path) as _f: + _q = _json.load(_f) + _gs = (_q.get("config_groups", {}).get("group_0", {}) + .get("weights", {}) or {}).get("group_size", 16) + model_config.hf_config.quantization_config = { + "quant_method": "nvfp4", + "group_size": _gs, + "modules_to_not_convert": _q.get("ignore", []) or ["lm_head"], + } + super().__init__(config, model_config, device_mesh) self._adapter = None self.hybrid_device_mesh = None diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index 0fc5a06cd8e..a998f50e6fd 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -54,6 +54,7 @@ def supports_partial_loading(self) -> bool: @control_action_decorator def update_weights(self, ipc_handles: Optional[dict] = None): + print(f"RAY_EXECUTOR_DEBUG: WorkerExtension.update_weights ENTER ipc_handles={'set' if ipc_handles is not None else 'None'}", flush=True) try: if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"): for module in self.engine.model_engine.model.modules(): @@ -107,6 +108,7 @@ def update_weights(self, ipc_handles: Optional[dict] = None): "device", "float32", "float16", + "bfloat16", "int32", "int64", "int16", From 7c5660b16b1e3a17e03f5bd87a50efe98a48ee34 Mon Sep 17 00:00:00 2001 From: Superjomn Date: Sun, 24 May 2026 05:47:29 -0700 Subject: [PATCH 2/5] [rollout] feat: enable W4A4 float8 actor->TRT-LLM 
rollout weight sync DAPO/PPO with QAT W4A4 (nvfp4) actor and TRT-LLM rollout failed at update_weights because the multiprocessing IPC path cannot move float8 tensors and TRT-LLM's restricted unpickler blocks the rebuild function. Four patches make the path work: 1. trtllm_rollout.py: register a float8 storage map and stub class so reduce_tensor can pickle float8 (float8_e4m3fn / float8_e5m2). Before reduce_tensor, move CPU-resident params to CUDA (param_offload=True path) and view the float8 tensor as uint8 storage; emit a 3-tuple (name, (rebuild_fn, rebuild_args), orig_dtype_str) per parameter so the receiver can restore the original dtype. 2. trtllm_worker_extension.py: mirror the float8 storage map and stub class on the receiver. Monkey-patch tensorrt_llm.serialization.loads to extend its restricted-unpickler allowlist with torch.multiprocessing.reductions.rebuild_cuda_tensor. Unpack the 3-tuple and view back to the original float8 dtype. 3. trtllm_async_server.py: route non-VLM TRT-LLM rollouts through verl's WorkerExtension so the receiver patches above actually take effect for non-VLM models. 4. utils/qat/linear.py: seed the W4A4 input_global_scale buffer to 1.0 instead of the uninitialized-scale sentinel so step-0 update_weights passes before any actor forward populates the real scale. Validated on Qwen3-8B-Base + DAPO + FSDP + TRT-LLM rollout + W4A4 (nvfp4), 2-node Lyris gb200, 9 RL training steps, 80 update_weights ENTER events across all WorkerDict ranks, timing_s/update_weights ~12.76s per step-1+ sync, no Tracebacks / OOM / NaN. Known limitations / out of scope for this PR: - The tensorrt_llm.serialization.loads monkey-patch is a workaround; a proper TRT-LLM API change would be cleaner. - The non-VLM -> verl WorkerExtension change in trtllm_async_server.py affects all non-VLM TRT-LLM rollouts in this tree. - W4A4 KL divergence is a separate documented issue.
--- verl/utils/qat/linear.py | 3 +- .../trtllm_rollout/trtllm_async_server.py | 30 ++------ .../rollout/trtllm_rollout/trtllm_rollout.py | 53 ++++++++++++-- .../trtllm_rollout/trtllm_worker_extension.py | 71 ++++++++++++++++++- 4 files changed, 126 insertions(+), 31 deletions(-) diff --git a/verl/utils/qat/linear.py b/verl/utils/qat/linear.py index 4b6c6bc8f41..b580972fa28 100644 --- a/verl/utils/qat/linear.py +++ b/verl/utils/qat/linear.py @@ -220,8 +220,9 @@ def __init__( self._fusion_siblings_ref = None if mode == QATMode.W4A4: + # Seed to 1.0 (not the -1.0 sentinel) so step-0 update_weights passes before any actor forward runs to populate the real scale. self.register_buffer( - "input_global_scale", torch.tensor([self._UNINITIALIZED_SCALE], dtype=torch.float32), persistent=True + "input_global_scale", torch.tensor([1.0], dtype=torch.float32), persistent=True ) self.register_buffer( diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 3da8baa467a..e1bf7cc494e 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -206,17 +206,15 @@ async def launch_server(self): **engine_kwargs, } + # Always use verl's WorkerExtension. The module-level imports register + # float8 storage stubs and monkey-patch tensorrt_llm.serialization.loads + # to allow torch.storage._load_from_bytes — required for W4A4 weight_scale + # (float8_e4m3fn) IPC. Without this, non-VLM models fell back to the + # parent TRT-LLM extension and the patches never ran on the server. 
self_defined_extension = { "ray_worker_extension_cls": "verl.workers.rollout.trtllm_rollout.trtllm_worker_extension.WorkerExtension", } - if self.is_vlm_model: - llm_kwargs.update(self_defined_extension) - else: - llm_kwargs.update( - { - "ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", - } - ) + llm_kwargs.update(self_defined_extension) if self.is_reward_model: llm_kwargs.update( @@ -311,22 +309,6 @@ async def generate( sampling_params=trt_llm_sampling_params, ) token_ids = outputs.outputs[0].token_ids - try: - import sys as _dbg_sys - _tids = list(token_ids) if token_ids is not None else [] - if _tids: - _mn = min(_tids); _mx = max(_tids) - _bad = _mx >= 200000 or _mn < 0 - print( - f"RAY_EXECUTOR_DEBUG: rollout.generate.token_ids: " - f"len={len(_tids)} min={_mn} max={_mx} bad={_bad} " - f"first5={_tids[:5]} last5={_tids[-5:]} " - f"finish_reason={outputs.outputs[0].finish_reason}", - file=_dbg_sys.stderr, flush=True, - ) - except Exception as _e: - import sys as _dbg_sys - print(f"RAY_EXECUTOR_DEBUG: rollout.generate.probe_error: {_e}", file=_dbg_sys.stderr, flush=True) log_probs = None if outputs.outputs[0].logprobs is not None: # When logprobs=1, TRT-LLM returns only the sampled token's logprob at each position diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py index 8ee387a79e0..1171ab2ee58 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py @@ -30,6 +30,38 @@ import torch import torch.distributed as dist +# Workaround: torch's legacy storage pickle path (used by +# multiprocessing.reductions.reduce_tensor when IPC-sharing tensors) is missing +# float8_e4m3fn / float8_e5m2 entries in some container builds. W4A4 QAT yields +# weight_scale as float8_e4m3fn (quantizer.py:66), so update_weights crashes +# inside pickle.dumps(cur_handles). 
Register the entries at import time so the +# legacy path round-trips float8 storages losslessly. +try: + from torch import storage as _torch_storage + _fwd = _torch_storage._dtype_to_storage_type_map + _bwd = _torch_storage._storage_type_to_dtype_map + if callable(_fwd): + _fwd = _fwd() + if callable(_bwd): + _bwd = _bwd() + for _dt, _name in ( + (torch.float8_e4m3fn, "Float8_e4m3fnStorage"), + (torch.float8_e5m2, "Float8_e5m2Storage"), + ): + if _dt not in _fwd: + _fwd[_dt] = _name + if _name not in _bwd: + _bwd[_name] = _dt + # The legacy save path (torch/serialization.py persistent_id) calls + # `getattr(torch, storage_type_str)` to embed module+name in the pickle + # stream. Register a stub class that subclasses UntypedStorage and + # carries the dtype, so getattr succeeds and round-trip works. + if not hasattr(torch, _name): + _stub = type(_name, (torch.UntypedStorage,), {"dtype": _dt, "__module__": "torch"}) + setattr(torch, _name, _stub) +except (AttributeError, ImportError): + pass + try: from tensorrt_llm.llmapi.llm_args import ExecutorMemoryType except (ImportError, RuntimeError): @@ -262,9 +294,7 @@ async def update_weights(self, weights: dict[str, str]): Returns: Dict[str, Any]: Server response containing update status """ - print(f"RAY_EXECUTOR_DEBUG: AsyncTRTLLMHttpAdapter.update_weights ENTER n_weights={len(weights) if weights else 0}", flush=True) result = await self._make_async_request("update_weights", {"weights": weights}) - print(f"RAY_EXECUTOR_DEBUG: AsyncTRTLLMHttpAdapter.update_weights EXIT", flush=True) return result @@ -517,8 +547,23 @@ async def flush(): ) cur_available_bytes -= size_in_bytes - handle = reduce_tensor(param.detach()) - cur_handles.append((name, handle)) + t = param.detach() + # With param_offload=True, weights may arrive on CPU; reduce_tensor + # falls back to legacy save for CPU tensors, yielding handle args + # incompatible with the receiver (which assumes rebuild_cuda_tensor + # 15-arg layout). 
Move to CUDA so reduce_tensor takes the IPC path. + if not t.is_cuda: + t = t.cuda() + # Float8 dtypes (e.g. W4A4 weight_scale) don't take CUDA IPC fast + # path in reduce_tensor; they fall through to legacy save which + # yields incompatible handle args. View as uint8 for transport and + # record original dtype so the receiver can view it back. + orig_dtype_str = None + if t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + orig_dtype_str = str(t.dtype).split(".")[-1] + t = t.view(torch.uint8) + handle = reduce_tensor(t) + cur_handles.append((name, handle, orig_dtype_str)) await flush() diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index a998f50e6fd..d9d59519338 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -15,6 +15,35 @@ import inspect from typing import Optional +import torch + +# Register float8 dtypes in torch's legacy storage pickle maps and create stub +# storage classes so the legacy save/load path round-trips W4A4 weight_scale +# (float8_e4m3fn) tensors via IPC. Must run in BOTH the verl rollout worker +# (sender) and the TRT-LLM rollout worker (receiver). This module is imported +# on the TRT-LLM side via WorkerExtension registration. 
+try: + from torch import storage as _torch_storage + _fwd = _torch_storage._dtype_to_storage_type_map + _bwd = _torch_storage._storage_type_to_dtype_map + if callable(_fwd): + _fwd = _fwd() + if callable(_bwd): + _bwd = _bwd() + for _dt, _name in ( + (torch.float8_e4m3fn, "Float8_e4m3fnStorage"), + (torch.float8_e5m2, "Float8_e5m2Storage"), + ): + if _dt not in _fwd: + _fwd[_dt] = _name + if _name not in _bwd: + _bwd[_name] = _dt + if not hasattr(torch, _name): + _stub = type(_name, (torch.UntypedStorage,), {"dtype": _dt, "__module__": "torch"}) + setattr(torch, _name, _stub) +except (AttributeError, ImportError): + pass + # Defer tensorrt_llm imports to avoid FlashInfer's check_cuda_arch() crash # when this module is loaded on CPU-only Ray actors. The module is normally # loaded only on GPU workers via string path in trtllm_async_server.py, but @@ -37,6 +66,36 @@ logger = None +# Monkey-patch tensorrt_llm.serialization.loads to permit torch.storage._load_from_bytes +# and float8 storage stubs. The parent rlhf_utils.WorkerExtension.update_weights calls +# serialization.loads with a hard-coded approved_imports that omits these — needed for +# W4A4 weight_scale (float8_e4m3fn) IPC round-trip. 
+if serialization is not None and not getattr(serialization, "_verl_w4a4_patched", False): + _orig_loads = serialization.loads + + def _patched_loads(data, approved_imports=None, **kwargs): + extra = { + "torch.storage": ["_load_from_bytes", "UntypedStorage", "TypedStorage", "_TypedStorage"], + "torch._utils": ["_rebuild_tensor_v2"], + "torch.multiprocessing.reductions": ["rebuild_cuda_tensor", "rebuild_tensor"], + "torch": ["Tensor", "Size", "dtype", "device", "Float8_e4m3fnStorage", "Float8_e5m2Storage", + "float8_e4m3fn", "float8_e5m2"], + } + if approved_imports is None: + approved_imports = {} + approved_imports = dict(approved_imports) + for _mod, _names in extra.items(): + merged = list(approved_imports.get(_mod, [])) + for _n in _names: + if _n not in merged: + merged.append(_n) + approved_imports[_mod] = merged + return _orig_loads(data, approved_imports=approved_imports, **kwargs) + + serialization.loads = _patched_loads + serialization._verl_w4a4_patched = True + + class WorkerExtension(TrtllmWorkerExtension): def __init__(self): pass @@ -54,7 +113,6 @@ def supports_partial_loading(self) -> bool: @control_action_decorator def update_weights(self, ipc_handles: Optional[dict] = None): - print(f"RAY_EXECUTOR_DEBUG: WorkerExtension.update_weights ENTER ipc_handles={'set' if ipc_handles is not None else 'None'}", flush=True) try: if not hasattr(self.engine.model_engine.model, "first_pre_reload_weights"): for module in self.engine.model_engine.model.modules(): @@ -142,11 +200,20 @@ def update_weights(self, ipc_handles: Optional[dict] = None): # Data is already in the correct format (backward compatibility) all_handles = serialized_handles - for param_name, tensor_handle in all_handles: + for entry in all_handles: + # Support both 2-tuple (legacy) and 3-tuple (with orig_dtype_str + # for float8 transport-as-uint8 round-trip). 
+ if len(entry) == 3: + param_name, tensor_handle, orig_dtype_str = entry + else: + param_name, tensor_handle = entry + orig_dtype_str = None func, args = tensor_handle list_args = list(args) list_args[6] = self.device_id tensor = func(*list_args) + if orig_dtype_str is not None: + tensor = tensor.view(getattr(torch, orig_dtype_str)) weights[param_name] = tensor logger.info(f"weights key size: {len(weights.keys())}") From 73610447dab39909130e2171ca6bf1eee50efb49 Mon Sep 17 00:00:00 2001 From: Superjomn Date: Sun, 24 May 2026 06:24:53 -0700 Subject: [PATCH 3/5] [chore] remove RAY_EXECUTOR_DEBUG probes from W4A4 debug session Pure observational prints across engine_workers, ray_trainer, model_forward, and engine/megatron/transformer_impl. Four functional W4A4 patches in 7c5660b1 are untouched. --- verl/models/mcore/model_forward.py | 21 ------------------- verl/trainer/ppo/ray_trainer.py | 6 ------ .../engine/megatron/transformer_impl.py | 2 -- verl/workers/engine_workers.py | 3 --- 4 files changed, 32 deletions(-) diff --git a/verl/models/mcore/model_forward.py b/verl/models/mcore/model_forward.py index f74f32a65b3..885e1f69730 100644 --- a/verl/models/mcore/model_forward.py +++ b/verl/models/mcore/model_forward.py @@ -281,27 +281,6 @@ def gptmodel_forward_model_engine( if vision_model: input_ids_rmpad, attention_mask = build_vlm_attn_mask_thd(input_ids, pad_token_id) - import sys as _sys - try: - _iids = input_ids_rmpad - if _iids is not None and hasattr(_iids, "numel") and _iids.numel() > 0: - _mn = int(_iids.min().item()) - _mx = int(_iids.max().item()) - _vs = getattr(model.config, "vocab_size", None) if hasattr(model, "config") else None - print( - f"RAY_EXECUTOR_DEBUG: model_forward.input_ids_rmpad: " - f"shape={tuple(_iids.shape)} min={_mn} max={_mx} vocab_size={_vs}", - file=_sys.stderr, flush=True, - ) - if _vs is not None and _mx >= _vs: - print( - f"RAY_EXECUTOR_DEBUG: model_forward.OOB_DETECTED: " - f"max_token_id={_mx} >= vocab_size={_vs}", - 
file=_sys.stderr, flush=True, - ) - except Exception as _e: - print(f"RAY_EXECUTOR_DEBUG: model_forward.probe_error: {_e}", file=_sys.stderr, flush=True) - output_orig = model( input_ids=input_ids_rmpad, attention_mask=attention_mask, diff --git a/verl/trainer/ppo/ray_trainer.py b/verl/trainer/ppo/ray_trainer.py index 846e09a3a58..08a4c436cee 100644 --- a/verl/trainer/ppo/ray_trainer.py +++ b/verl/trainer/ppo/ray_trainer.py @@ -1300,9 +1300,7 @@ def fit(self): # load checkpoint and update weights before doing anything self._load_checkpoint() - print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps}", flush=True) self.checkpoint_manager.update_weights(self.global_steps) - print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps}", flush=True) current_epoch = self.global_steps // len(self.train_dataloader) @@ -1556,9 +1554,7 @@ def fit(self): # implement critic warmup if self.config.trainer.critic_warmup > self.global_steps: # Still in critic warmup, only update weights to wake up rollout replicas. 
- print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps} (critic_warmup)", flush=True) self.checkpoint_manager.update_weights(self.global_steps) - print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps} (critic_warmup)", flush=True) else: # update actor with marked_timer("update_actor", timing_raw, color="red"): @@ -1588,9 +1584,7 @@ def fit(self): # update weights from trainer to rollout with marked_timer("update_weights", timing_raw, color="red"): - print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights ENTER step={self.global_steps}", flush=True) self.checkpoint_manager.update_weights(self.global_steps) - print(f"RAY_EXECUTOR_DEBUG: ray_trainer.update_weights EXIT step={self.global_steps}", flush=True) actor_output_metrics = reduce_metrics(actor_output.meta_info["metrics"]) metrics.update(actor_output_metrics) diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py index 1956c8c517b..cbb7be48141 100644 --- a/verl/workers/engine/megatron/transformer_impl.py +++ b/verl/workers/engine/megatron/transformer_impl.py @@ -730,12 +730,10 @@ def get_per_tensor_param(self, base_sync_done=False, **kwargs): ) # QAT: process weights through QATWeightExporter for quantized weight sync to vLLM - print(f"RAY_EXECUTOR_DEBUG: transformer_impl.get_per_tensor_param _qat_enabled={self._qat_enabled} type={type(per_tensor_param).__name__}", flush=True) if self._qat_enabled: from verl.utils.modelopt import export_qat_weights per_tensor_param = export_qat_weights(per_tensor_param, self.module, self._qat_config.mode, self.bridge) - print(f"RAY_EXECUTOR_DEBUG: transformer_impl.get_per_tensor_param post_qat type={type(per_tensor_param).__name__}", flush=True) return per_tensor_param, peft_config diff --git a/verl/workers/engine_workers.py b/verl/workers/engine_workers.py index 832a01fb0ad..74ecb711def 100644 --- a/verl/workers/engine_workers.py +++ 
b/verl/workers/engine_workers.py @@ -673,7 +673,6 @@ async def update_weights(self, global_steps: int = None): peft_config=None, so the rollout receives a standard weight update. """ - print(f"RAY_EXECUTOR_DEBUG: engine_workers.update_weights ENTER global_steps={global_steps}", flush=True) # 0. send_weights only for async training with disaggregated trainer and rollout if self.config.rollout.checkpoint_engine.backend != "naive": per_tensor_param, _ = self.actor.engine.get_per_tensor_param() @@ -707,11 +706,9 @@ async def update_weights(self, global_steps: int = None): per_tensor_param_base, peft_config=peft_config, base_sync_done=False, global_steps=global_steps ) - print(f"RAY_EXECUTOR_DEBUG: engine_workers.before_rollout_update_weights global_steps={global_steps}", flush=True) await self.rollout.update_weights( per_tensor_param, peft_config=peft_config, base_sync_done=True, global_steps=global_steps ) - print(f"RAY_EXECUTOR_DEBUG: engine_workers.after_rollout_update_weights global_steps={global_steps}", flush=True) log_gpu_memory_usage("After update_weights", logger=logger) From eaa7030727bcb2dd8d394d80b7a0359951e404f7 Mon Sep 17 00:00:00 2001 From: Superjomn Date: Sun, 24 May 2026 07:05:27 -0700 Subject: [PATCH 4/5] [chore] drop observational logger.info from W4A4 rollout init Two status logs added during the debug session (NVFP4 QAT injection and force_dynamic_quantization flag activation). The functional config injections themselves are unchanged. 
--- verl/workers/rollout/trtllm_rollout/trtllm_async_server.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index e1bf7cc494e..23f8e358a9e 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -171,7 +171,6 @@ async def launch_server(self): } if self.config.load_format != "dummy": raise ValueError("NVFP4 QAT rollout requires load_format=dummy") - logger.info(f"NVFP4 QAT injected: group_size={_gs}") llm_kwargs = { "model": self.model_config.local_path, @@ -242,7 +241,6 @@ async def launch_server(self): _fdq = bool(getattr(self.config, "force_dynamic_quantization", False)) if _fdq: llm_kwargs["force_dynamic_quantization"] = True - logger.info("force_dynamic_quantization=True set on TRT-LLM rollout") self.llm = await AsyncLLM(**llm_kwargs) import inspect From 84270ebc4b20a09defc309dfca2ba03b21c46e3a Mon Sep 17 00:00:00 2001 From: Superjomn Date: Sun, 24 May 2026 07:35:45 -0700 Subject: [PATCH 5/5] [chore] dedup W4A4 helpers; drop unrelated submodule bump Factor the duplicated float8 storage-stub registration and NVFP4 quantization-config builder into _w4a4_compat. Inline a stray result-variable refactor and drop the _fdq tempvar. Revert the unrelated recipe submodule bump that slipped in with 7c5660b1. No behavior change. 
--- .../rollout/trtllm_rollout/_w4a4_compat.py | 74 +++++++++++++++++++ .../trtllm_rollout/trtllm_async_server.py | 18 ++--- .../rollout/trtllm_rollout/trtllm_rollout.py | 51 +------------ .../trtllm_rollout/trtllm_worker_extension.py | 28 +------ 4 files changed, 84 insertions(+), 87 deletions(-) create mode 100644 verl/workers/rollout/trtllm_rollout/_w4a4_compat.py diff --git a/verl/workers/rollout/trtllm_rollout/_w4a4_compat.py b/verl/workers/rollout/trtllm_rollout/_w4a4_compat.py new file mode 100644 index 00000000000..2a9a7042f50 --- /dev/null +++ b/verl/workers/rollout/trtllm_rollout/_w4a4_compat.py @@ -0,0 +1,74 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Internal helpers shared by the verl-side TRT-LLM rollout adapter and the +TRT-LLM rollout worker extension for W4A4 (NVFP4 QAT) weight sync over IPC.""" + +import json +import torch + + +def _register_float8_storage_stubs() -> None: + """Register float8_e4m3fn / float8_e5m2 in torch's legacy storage pickle + maps and create UntypedStorage stub classes so the legacy save/load path + round-trips W4A4 weight_scale tensors via IPC. + + Must run on both the sender (verl rollout worker) and receiver (TRT-LLM + rollout worker). Safe to call multiple times. 
+ """ + try: + from torch import storage as torch_storage + fwd = torch_storage._dtype_to_storage_type_map + bwd = torch_storage._storage_type_to_dtype_map + if callable(fwd): + fwd = fwd() + if callable(bwd): + bwd = bwd() + for dt, name in ( + (torch.float8_e4m3fn, "Float8_e4m3fnStorage"), + (torch.float8_e5m2, "Float8_e5m2Storage"), + ): + if dt not in fwd: + fwd[dt] = name + if name not in bwd: + bwd[name] = dt + if not hasattr(torch, name): + stub = type(name, (torch.UntypedStorage,), {"dtype": dt, "__module__": "torch"}) + setattr(torch, name, stub) + except (AttributeError, ImportError): + pass + + +def build_nvfp4_quantization_config(qat_cfg) -> dict: + """Build the HF/TRT-LLM quantization_config dict from a verl QAT config. + + `qat_cfg` is the rollout-side QAT sub-config (OmegaConf-style with + `.get(...)`). Returns the dict that goes into either + `engine_kwargs["model_kwargs"]["quantization_config"]` (TRT-LLM side) or + `model_config.hf_config.quantization_config` (verl ServerAdapter side). + """ + quant_json_path = qat_cfg.get("quantization_config_path") + with open(quant_json_path) as f: + quant_json = json.load(f) + group_size = ( + (quant_json.get("config_groups", {}).get("group_0", {}).get("weights", {}) or {}) + .get("group_size", 16) + ) + return { + "quant_method": "nvfp4", + "group_size": group_size, + "modules_to_not_convert": quant_json.get("ignore", []) or ["lm_head"], + } + + +_register_float8_storage_stubs() diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 23f8e358a9e..394e9ce4599 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -158,17 +158,10 @@ async def launch_server(self): # rollout must build NVFP4-packed linears to receive them. 
qat_cfg = getattr(self.config, "qat", None) if qat_cfg is not None and qat_cfg.get("enable", False): - import json as _json - _quant_json_path = qat_cfg.get("quantization_config_path") - with open(_quant_json_path) as _f: - _quant_json = _json.load(_f) - _gs = (_quant_json.get("config_groups", {}).get("group_0", {}) - .get("weights", {}) or {}).get("group_size", 16) - engine_kwargs.setdefault("model_kwargs", {})["quantization_config"] = { - "quant_method": "nvfp4", - "group_size": _gs, - "modules_to_not_convert": _quant_json.get("ignore", []) or ["lm_head"], - } + from verl.workers.rollout.trtllm_rollout._w4a4_compat import build_nvfp4_quantization_config + engine_kwargs.setdefault("model_kwargs", {})["quantization_config"] = ( + build_nvfp4_quantization_config(qat_cfg) + ) if self.config.load_format != "dummy": raise ValueError("NVFP4 QAT rollout requires load_format=dummy") @@ -238,8 +231,7 @@ async def launch_server(self): # Experimental: force TRT-LLM to dynamically re-quantize incoming weights # at reload time. Reads actor_rollout_ref.rollout.force_dynamic_quantization. - _fdq = bool(getattr(self.config, "force_dynamic_quantization", False)) - if _fdq: + if getattr(self.config, "force_dynamic_quantization", False): llm_kwargs["force_dynamic_quantization"] = True self.llm = await AsyncLLM(**llm_kwargs) diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py index 1171ab2ee58..6c081db3cca 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py @@ -29,38 +29,7 @@ import ray import torch import torch.distributed as dist - -# Workaround: torch's legacy storage pickle path (used by -# multiprocessing.reductions.reduce_tensor when IPC-sharing tensors) is missing -# float8_e4m3fn / float8_e5m2 entries in some container builds. 
W4A4 QAT yields -# weight_scale as float8_e4m3fn (quantizer.py:66), so update_weights crashes -# inside pickle.dumps(cur_handles). Register the entries at import time so the -# legacy path round-trips float8 storages losslessly. -try: - from torch import storage as _torch_storage - _fwd = _torch_storage._dtype_to_storage_type_map - _bwd = _torch_storage._storage_type_to_dtype_map - if callable(_fwd): - _fwd = _fwd() - if callable(_bwd): - _bwd = _bwd() - for _dt, _name in ( - (torch.float8_e4m3fn, "Float8_e4m3fnStorage"), - (torch.float8_e5m2, "Float8_e5m2Storage"), - ): - if _dt not in _fwd: - _fwd[_dt] = _name - if _name not in _bwd: - _bwd[_name] = _dt - # The legacy save path (torch/serialization.py persistent_id) calls - # `getattr(torch, storage_type_str)` to embed module+name in the pickle - # stream. Register a stub class that subclasses UntypedStorage and - # carries the dtype, so getattr succeeds and round-trip works. - if not hasattr(torch, _name): - _stub = type(_name, (torch.UntypedStorage,), {"dtype": _dt, "__module__": "torch"}) - setattr(torch, _name, _stub) -except (AttributeError, ImportError): - pass +from verl.workers.rollout.trtllm_rollout import _w4a4_compat # noqa: F401 (registers float8 storage stubs at import) try: from tensorrt_llm.llmapi.llm_args import ExecutorMemoryType @@ -294,8 +263,7 @@ async def update_weights(self, weights: dict[str, str]): Returns: Dict[str, Any]: Server response containing update status """ - result = await self._make_async_request("update_weights", {"weights": weights}) - return result + return await self._make_async_request("update_weights", {"weights": weights}) class ServerAdapter(BaseRollout): @@ -337,19 +305,8 @@ def __init__( # NVFP4 QAT — mirror actor's quant config onto hf_config so ServerAdapter # weight-sync sees NVFP4 metadata (same role as the FP8 block above). 
qat_cfg = config.get("qat", None) - if qat_cfg is not None and qat_cfg.get("enable", False): - qat_path = qat_cfg.get("quantization_config_path") - if qat_path: - import json as _json - with open(qat_path) as _f: - _q = _json.load(_f) - _gs = (_q.get("config_groups", {}).get("group_0", {}) - .get("weights", {}) or {}).get("group_size", 16) - model_config.hf_config.quantization_config = { - "quant_method": "nvfp4", - "group_size": _gs, - "modules_to_not_convert": _q.get("ignore", []) or ["lm_head"], - } + if qat_cfg is not None and qat_cfg.get("enable", False) and qat_cfg.get("quantization_config_path"): + model_config.hf_config.quantization_config = _w4a4_compat.build_nvfp4_quantization_config(qat_cfg) super().__init__(config, model_config, device_mesh) self._adapter = None diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index d9d59519338..73d43182d2e 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -16,33 +16,7 @@ from typing import Optional import torch - -# Register float8 dtypes in torch's legacy storage pickle maps and create stub -# storage classes so the legacy save/load path round-trips W4A4 weight_scale -# (float8_e4m3fn) tensors via IPC. Must run in BOTH the verl rollout worker -# (sender) and the TRT-LLM rollout worker (receiver). This module is imported -# on the TRT-LLM side via WorkerExtension registration. 
-try: - from torch import storage as _torch_storage - _fwd = _torch_storage._dtype_to_storage_type_map - _bwd = _torch_storage._storage_type_to_dtype_map - if callable(_fwd): - _fwd = _fwd() - if callable(_bwd): - _bwd = _bwd() - for _dt, _name in ( - (torch.float8_e4m3fn, "Float8_e4m3fnStorage"), - (torch.float8_e5m2, "Float8_e5m2Storage"), - ): - if _dt not in _fwd: - _fwd[_dt] = _name - if _name not in _bwd: - _bwd[_name] = _dt - if not hasattr(torch, _name): - _stub = type(_name, (torch.UntypedStorage,), {"dtype": _dt, "__module__": "torch"}) - setattr(torch, _name, _stub) -except (AttributeError, ImportError): - pass +from verl.workers.rollout.trtllm_rollout import _w4a4_compat # noqa: F401 (registers float8 storage stubs at import) # Defer tensorrt_llm imports to avoid FlashInfer's check_cuda_arch() crash # when this module is loaded on CPU-only Ray actors. The module is normally