diff --git a/recipe b/recipe index ba246418f4d..e7f889574b8 160000 --- a/recipe +++ b/recipe @@ -1 +1 @@ -Subproject commit ba246418f4de12b845a09bba975f1a5242adc898 +Subproject commit e7f889574b8301cc0f0fc1d57c6d67f31ffeb689 diff --git a/verl/trainer/config/rollout/rollout.yaml b/verl/trainer/config/rollout/rollout.yaml index a0f3eb92670..1536951be2b 100644 --- a/verl/trainer/config/rollout/rollout.yaml +++ b/verl/trainer/config/rollout/rollout.yaml @@ -446,3 +446,7 @@ mtp: ${oc.select:actor_rollout_ref.model.mtp, null} # QAT configuration (inherited from actor's engine config) qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}} + +# Experimental: force TRT-LLM to dynamically re-quantize incoming weights at reload time. +# Only consumed by the TRT-LLM rollout backend (see workers/rollout/trtllm_rollout/trtllm_async_server.py). +force_dynamic_quantization: False diff --git a/verl/utils/qat/linear.py b/verl/utils/qat/linear.py index 4b6c6bc8f41..b580972fa28 100644 --- a/verl/utils/qat/linear.py +++ b/verl/utils/qat/linear.py @@ -220,8 +220,9 @@ def __init__( self._fusion_siblings_ref = None if mode == QATMode.W4A4: + # Seed to 1.0 (not the -1.0 sentinel) so step-0 update_weights passes before any actor forward runs to populate the real scale. self.register_buffer( - "input_global_scale", torch.tensor([self._UNINITIALIZED_SCALE], dtype=torch.float32), persistent=True + "input_global_scale", torch.tensor([1.0], dtype=torch.float32), persistent=True ) self.register_buffer( diff --git a/verl/workers/config/rollout.py b/verl/workers/config/rollout.py index fa8b9573fa6..a3e2422e0e9 100644 --- a/verl/workers/config/rollout.py +++ b/verl/workers/config/rollout.py @@ -287,6 +287,8 @@ class RolloutConfig(BaseConfig): qat: Optional[dict] = None + force_dynamic_quantization: bool = False + def __post_init__(self): """Validate the rollout config""" # Deprecation warning for mode field - only async mode is supported diff --git a/verl/workers/rollout/trtllm_rollout/_w4a4_compat.py b/verl/workers/rollout/trtllm_rollout/_w4a4_compat.py new file mode 100644 index 00000000000..2a9a7042f50 --- /dev/null +++ b/verl/workers/rollout/trtllm_rollout/_w4a4_compat.py @@ -0,0 +1,74 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Internal helpers shared by the verl-side TRT-LLM rollout adapter and the +TRT-LLM rollout worker extension for W4A4 (NVFP4 QAT) weight sync over IPC.""" + +import json +import torch + + +def _register_float8_storage_stubs() -> None: + """Register float8_e4m3fn / float8_e5m2 in torch's legacy storage pickle + maps and create UntypedStorage stub classes so the legacy save/load path + round-trips W4A4 weight_scale tensors via IPC. + + Must run on both the sender (verl rollout worker) and receiver (TRT-LLM + rollout worker). Safe to call multiple times. + """ + try: + from torch import storage as torch_storage + fwd = torch_storage._dtype_to_storage_type_map + bwd = torch_storage._storage_type_to_dtype_map + if callable(fwd): + fwd = fwd() + if callable(bwd): + bwd = bwd() + for dt, name in ( + (torch.float8_e4m3fn, "Float8_e4m3fnStorage"), + (torch.float8_e5m2, "Float8_e5m2Storage"), + ): + if dt not in fwd: + fwd[dt] = name + if name not in bwd: + bwd[name] = dt + if not hasattr(torch, name): + stub = type(name, (torch.UntypedStorage,), {"dtype": dt, "__module__": "torch"}) + setattr(torch, name, stub) + except (AttributeError, ImportError): + pass + + +def build_nvfp4_quantization_config(qat_cfg) -> dict: + """Build the HF/TRT-LLM quantization_config dict from a verl QAT config. + + `qat_cfg` is the rollout-side QAT sub-config (OmegaConf-style with + `.get(...)`). Returns the dict that goes into either + `engine_kwargs["model_kwargs"]["quantization_config"]` (TRT-LLM side) or + `model_config.hf_config.quantization_config` (verl ServerAdapter side). + """ + quant_json_path = qat_cfg.get("quantization_config_path") + with open(quant_json_path) as f: + quant_json = json.load(f) + group_size = ( + (quant_json.get("config_groups", {}).get("group_0", {}).get("weights", {}) or {}) + .get("group_size", 16) + ) + return { + "quant_method": "nvfp4", + "group_size": group_size, + "modules_to_not_convert": quant_json.get("ignore", []) or ["lm_head"], + } + + +_register_float8_storage_stubs() diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py index 0a1c1147495..394e9ce4599 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_async_server.py @@ -154,6 +154,17 @@ async def launch_server(self): else: raise ValueError(f"Currently only support fp8 quantization, got: {quantization}") + # NVFP4 QAT — actor (Megatron+modelopt) emits real-quantized NVFP4 weights; + # rollout must build NVFP4-packed linears to receive them. + qat_cfg = getattr(self.config, "qat", None) + if qat_cfg is not None and qat_cfg.get("enable", False): + from verl.workers.rollout.trtllm_rollout._w4a4_compat import build_nvfp4_quantization_config + engine_kwargs.setdefault("model_kwargs", {})["quantization_config"] = ( + build_nvfp4_quantization_config(qat_cfg) + ) + if self.config.load_format != "dummy": + raise ValueError("NVFP4 QAT rollout requires load_format=dummy") + llm_kwargs = { "model": self.model_config.local_path, "backend": "pytorch", @@ -187,17 +198,15 @@ async def launch_server(self): **engine_kwargs, } + # Always use verl's WorkerExtension. The module-level imports register + # float8 storage stubs and monkey-patch tensorrt_llm.serialization.loads + # to allow torch.storage._load_from_bytes — required for W4A4 weight_scale + # (float8_e4m3fn) IPC. Without this, non-VLM models fell back to the + # parent TRT-LLM extension and the patches never ran on the server. self_defined_extension = { "ray_worker_extension_cls": "verl.workers.rollout.trtllm_rollout.trtllm_worker_extension.WorkerExtension", } - if self.is_vlm_model: - llm_kwargs.update(self_defined_extension) - else: - llm_kwargs.update( - { - "ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension", - } - ) + llm_kwargs.update(self_defined_extension) if self.is_reward_model: llm_kwargs.update( @@ -220,6 +229,11 @@ async def launch_server(self): } ) + # Experimental: force TRT-LLM to dynamically re-quantize incoming weights + # at reload time. Reads actor_rollout_ref.rollout.force_dynamic_quantization. + if getattr(self.config, "force_dynamic_quantization", False): + llm_kwargs["force_dynamic_quantization"] = True + self.llm = await AsyncLLM(**llm_kwargs) import inspect diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py index 4fbf0e614f4..6c081db3cca 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_rollout.py @@ -29,6 +29,7 @@ import ray import torch import torch.distributed as dist +from verl.workers.rollout.trtllm_rollout import _w4a4_compat # noqa: F401 (registers float8 storage stubs at import) try: from tensorrt_llm.llmapi.llm_args import ExecutorMemoryType @@ -300,6 +301,13 @@ def __init__( } fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS) model_config.hf_config.quantization_config = fp8_block_quant_kwargs + + # NVFP4 QAT — mirror actor's quant config onto hf_config so ServerAdapter + # weight-sync sees NVFP4 metadata (same role as the FP8 block above). + qat_cfg = config.get("qat", None) + if qat_cfg is not None and qat_cfg.get("enable", False) and qat_cfg.get("quantization_config_path"): + model_config.hf_config.quantization_config = _w4a4_compat.build_nvfp4_quantization_config(qat_cfg) + super().__init__(config, model_config, device_mesh) self._adapter = None self.hybrid_device_mesh = None @@ -496,8 +504,23 @@ async def flush(): ) cur_available_bytes -= size_in_bytes - handle = reduce_tensor(param.detach()) - cur_handles.append((name, handle)) + t = param.detach() + # With param_offload=True, weights may arrive on CPU; reduce_tensor + # falls back to legacy save for CPU tensors, yielding handle args + # incompatible with the receiver (which assumes rebuild_cuda_tensor + # 15-arg layout). Move to CUDA so reduce_tensor takes the IPC path. + if not t.is_cuda: + t = t.cuda() + # Float8 dtypes (e.g. W4A4 weight_scale) don't take CUDA IPC fast + # path in reduce_tensor; they fall through to legacy save which + # yields incompatible handle args. View as uint8 for transport and + # record original dtype so the receiver can view it back. + orig_dtype_str = None + if t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2): + orig_dtype_str = str(t.dtype).split(".")[-1] + t = t.view(torch.uint8) + handle = reduce_tensor(t) + cur_handles.append((name, handle, orig_dtype_str)) await flush() diff --git a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py index 0fc5a06cd8e..73d43182d2e 100644 --- a/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py +++ b/verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py @@ -15,6 +15,9 @@ import inspect from typing import Optional +import torch +from verl.workers.rollout.trtllm_rollout import _w4a4_compat # noqa: F401 (registers float8 storage stubs at import) + # Defer tensorrt_llm imports to avoid FlashInfer's check_cuda_arch() crash # when this module is loaded on CPU-only Ray actors. The module is normally # loaded only on GPU workers via string path in trtllm_async_server.py, but @@ -37,6 +40,36 @@ logger = None +# Monkey-patch tensorrt_llm.serialization.loads to permit torch.storage._load_from_bytes +# and float8 storage stubs. The parent rlhf_utils.WorkerExtension.update_weights calls +# serialization.loads with a hard-coded approved_imports that omits these — needed for +# W4A4 weight_scale (float8_e4m3fn) IPC round-trip. +if serialization is not None and not getattr(serialization, "_verl_w4a4_patched", False): + _orig_loads = serialization.loads + + def _patched_loads(data, approved_imports=None, **kwargs): + extra = { + "torch.storage": ["_load_from_bytes", "UntypedStorage", "TypedStorage", "_TypedStorage"], + "torch._utils": ["_rebuild_tensor_v2"], + "torch.multiprocessing.reductions": ["rebuild_cuda_tensor", "rebuild_tensor"], + "torch": ["Tensor", "Size", "dtype", "device", "Float8_e4m3fnStorage", "Float8_e5m2Storage", + "float8_e4m3fn", "float8_e5m2"], + } + if approved_imports is None: + approved_imports = {} + approved_imports = dict(approved_imports) + for _mod, _names in extra.items(): + merged = list(approved_imports.get(_mod, [])) + for _n in _names: + if _n not in merged: + merged.append(_n) + approved_imports[_mod] = merged + return _orig_loads(data, approved_imports=approved_imports, **kwargs) + + serialization.loads = _patched_loads + serialization._verl_w4a4_patched = True + + class WorkerExtension(TrtllmWorkerExtension): def __init__(self): pass @@ -107,6 +140,7 @@ def update_weights(self, ipc_handles: Optional[dict] = None): "device", "float32", "float16", + "bfloat16", "int32", "int64", "int16", @@ -140,11 +174,20 @@ def update_weights(self, ipc_handles: Optional[dict] = None): # Data is already in the correct format (backward compatibility) all_handles = serialized_handles - for param_name, tensor_handle in all_handles: + for entry in all_handles: + # Support both 2-tuple (legacy) and 3-tuple (with orig_dtype_str + # for float8 transport-as-uint8 round-trip). + if len(entry) == 3: + param_name, tensor_handle, orig_dtype_str = entry + else: + param_name, tensor_handle = entry + orig_dtype_str = None func, args = tensor_handle list_args = list(args) list_args[6] = self.device_id tensor = func(*list_args) + if orig_dtype_str is not None: + tensor = tensor.view(getattr(torch, orig_dtype_str)) weights[param_name] = tensor logger.info(f"weights key size: {len(weights.keys())}")