Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion recipe
Submodule recipe updated 131 files
4 changes: 4 additions & 0 deletions verl/trainer/config/rollout/rollout.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -446,3 +446,7 @@ mtp: ${oc.select:actor_rollout_ref.model.mtp, null}

# QAT configuration (inherited from actor's engine config)
qat: ${oc.select:actor_rollout_ref.actor.fsdp_config.qat,${oc.select:actor_rollout_ref.actor.megatron.qat,null}}

# Experimental: force TRT-LLM to dynamically re-quantize incoming weights at reload time.
# Only consumed by the TRT-LLM rollout backend (see workers/rollout/trtllm_rollout/trtllm_async_server.py).
force_dynamic_quantization: False
3 changes: 2 additions & 1 deletion verl/utils/qat/linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,9 @@ def __init__(
self._fusion_siblings_ref = None

if mode == QATMode.W4A4:
# Seed to 1.0 (not the -1.0 sentinel) so step-0 update_weights passes before any actor forward runs to populate the real scale.
self.register_buffer(
"input_global_scale", torch.tensor([self._UNINITIALIZED_SCALE], dtype=torch.float32), persistent=True
"input_global_scale", torch.tensor([1.0], dtype=torch.float32), persistent=True
)

self.register_buffer(
Expand Down
2 changes: 2 additions & 0 deletions verl/workers/config/rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,8 @@ class RolloutConfig(BaseConfig):

qat: Optional[dict] = None

force_dynamic_quantization: bool = False

def __post_init__(self):
"""Validate the rollout config"""
# Deprecation warning for mode field - only async mode is supported
Expand Down
74 changes: 74 additions & 0 deletions verl/workers/rollout/trtllm_rollout/_w4a4_compat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
# Copyright 2026 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Internal helpers shared by the verl-side TRT-LLM rollout adapter and the
TRT-LLM rollout worker extension for W4A4 (NVFP4 QAT) weight sync over IPC."""

import json
import torch


def _register_float8_storage_stubs() -> None:
"""Register float8_e4m3fn / float8_e5m2 in torch's legacy storage pickle
maps and create UntypedStorage stub classes so the legacy save/load path
round-trips W4A4 weight_scale tensors via IPC.

Must run on both the sender (verl rollout worker) and receiver (TRT-LLM
rollout worker). Safe to call multiple times.
"""
try:
from torch import storage as torch_storage
fwd = torch_storage._dtype_to_storage_type_map
bwd = torch_storage._storage_type_to_dtype_map
if callable(fwd):
fwd = fwd()
if callable(bwd):
bwd = bwd()
for dt, name in (
(torch.float8_e4m3fn, "Float8_e4m3fnStorage"),
(torch.float8_e5m2, "Float8_e5m2Storage"),
):
if dt not in fwd:
fwd[dt] = name
if name not in bwd:
bwd[name] = dt
if not hasattr(torch, name):
stub = type(name, (torch.UntypedStorage,), {"dtype": dt, "__module__": "torch"})
setattr(torch, name, stub)
except (AttributeError, ImportError):
pass


def build_nvfp4_quantization_config(qat_cfg) -> dict:
    """Build the HF/TRT-LLM quantization_config dict from a verl QAT config.

    Args:
        qat_cfg: The rollout-side QAT sub-config (OmegaConf-style, anything
            exposing ``.get(...)``). Its ``quantization_config_path`` entry
            must point to a JSON file describing the quantization recipe.

    Returns:
        A dict with three keys:

        * ``"quant_method"``: always ``"nvfp4"``.
        * ``"group_size"``: ``config_groups.group_0.weights.group_size`` from
          the JSON file, or ``16`` when that key (or any parent) is absent.
        * ``"modules_to_not_convert"``: the JSON ``ignore`` list, or
          ``["lm_head"]`` when it is missing or empty.

        The result goes into either
        ``engine_kwargs["model_kwargs"]["quantization_config"]`` (TRT-LLM
        side) or ``model_config.hf_config.quantization_config`` (verl
        ServerAdapter side).

    Raises:
        ValueError: If ``quantization_config_path`` is unset or empty.
        OSError: If the file cannot be opened.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    quant_json_path = qat_cfg.get("quantization_config_path")
    if not quant_json_path:
        # Fail with an actionable message instead of the opaque TypeError
        # that open(None) would raise.
        raise ValueError(
            "qat.quantization_config_path must be set to build the NVFP4 quantization config"
        )

    with open(quant_json_path, encoding="utf-8") as f:
        quant_json = json.load(f)

    # Treat an explicit JSON null at any nesting level the same as a missing
    # key, so a sparse config falls back to the default group size instead of
    # raising AttributeError on None.
    config_groups = quant_json.get("config_groups") or {}
    group_0 = config_groups.get("group_0") or {}
    weights = group_0.get("weights") or {}

    return {
        "quant_method": "nvfp4",
        "group_size": weights.get("group_size", 16),
        "modules_to_not_convert": quant_json.get("ignore") or ["lm_head"],
    }


# Import-time side effect: run the float8 storage-stub registration as soon as
# this module is imported, so importing it is enough to apply the registration.
_register_float8_storage_stubs()
30 changes: 22 additions & 8 deletions verl/workers/rollout/trtllm_rollout/trtllm_async_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,17 @@ async def launch_server(self):
else:
raise ValueError(f"Currently only support fp8 quantization, got: {quantization}")

# NVFP4 QAT — actor (Megatron+modelopt) emits real-quantized NVFP4 weights;
# rollout must build NVFP4-packed linears to receive them.
qat_cfg = getattr(self.config, "qat", None)
if qat_cfg is not None and qat_cfg.get("enable", False):
from verl.workers.rollout.trtllm_rollout._w4a4_compat import build_nvfp4_quantization_config
engine_kwargs.setdefault("model_kwargs", {})["quantization_config"] = (
build_nvfp4_quantization_config(qat_cfg)
)
if self.config.load_format != "dummy":
raise ValueError("NVFP4 QAT rollout requires load_format=dummy")

llm_kwargs = {
"model": self.model_config.local_path,
"backend": "pytorch",
Expand Down Expand Up @@ -187,17 +198,15 @@ async def launch_server(self):
**engine_kwargs,
}

# Always use verl's WorkerExtension. The module-level imports register
# float8 storage stubs and monkey-patch tensorrt_llm.serialization.loads
# to allow torch.storage._load_from_bytes — required for W4A4 weight_scale
# (float8_e4m3fn) IPC. Without this, non-VLM models fell back to the
# parent TRT-LLM extension and the patches never ran on the server.
self_defined_extension = {
"ray_worker_extension_cls": "verl.workers.rollout.trtllm_rollout.trtllm_worker_extension.WorkerExtension",
}
if self.is_vlm_model:
llm_kwargs.update(self_defined_extension)
else:
llm_kwargs.update(
{
"ray_worker_extension_cls": "tensorrt_llm.llmapi.rlhf_utils.WorkerExtension",
}
)
llm_kwargs.update(self_defined_extension)

if self.is_reward_model:
llm_kwargs.update(
Expand All @@ -220,6 +229,11 @@ async def launch_server(self):
}
)

# Experimental: force TRT-LLM to dynamically re-quantize incoming weights
# at reload time. Reads actor_rollout_ref.rollout.force_dynamic_quantization.
if getattr(self.config, "force_dynamic_quantization", False):
llm_kwargs["force_dynamic_quantization"] = True

self.llm = await AsyncLLM(**llm_kwargs)
import inspect

Expand Down
27 changes: 25 additions & 2 deletions verl/workers/rollout/trtllm_rollout/trtllm_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import ray
import torch
import torch.distributed as dist
from verl.workers.rollout.trtllm_rollout import _w4a4_compat # noqa: F401 (registers float8 storage stubs at import)

try:
from tensorrt_llm.llmapi.llm_args import ExecutorMemoryType
Expand Down Expand Up @@ -300,6 +301,13 @@ def __init__(
}
fp8_block_quant_kwargs = dict(FP8_BLOCK_QUANT_KWARGS)
model_config.hf_config.quantization_config = fp8_block_quant_kwargs

# NVFP4 QAT — mirror actor's quant config onto hf_config so ServerAdapter
# weight-sync sees NVFP4 metadata (same role as the FP8 block above).
qat_cfg = config.get("qat", None)
if qat_cfg is not None and qat_cfg.get("enable", False) and qat_cfg.get("quantization_config_path"):
model_config.hf_config.quantization_config = _w4a4_compat.build_nvfp4_quantization_config(qat_cfg)

super().__init__(config, model_config, device_mesh)
self._adapter = None
self.hybrid_device_mesh = None
Expand Down Expand Up @@ -496,8 +504,23 @@ async def flush():
)
cur_available_bytes -= size_in_bytes

handle = reduce_tensor(param.detach())
cur_handles.append((name, handle))
t = param.detach()
# With param_offload=True, weights may arrive on CPU; reduce_tensor
# falls back to legacy save for CPU tensors, yielding handle args
# incompatible with the receiver (which assumes rebuild_cuda_tensor
# 15-arg layout). Move to CUDA so reduce_tensor takes the IPC path.
if not t.is_cuda:
t = t.cuda()
# Float8 dtypes (e.g. W4A4 weight_scale) don't take CUDA IPC fast
# path in reduce_tensor; they fall through to legacy save which
# yields incompatible handle args. View as uint8 for transport and
# record original dtype so the receiver can view it back.
orig_dtype_str = None
if t.dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
orig_dtype_str = str(t.dtype).split(".")[-1]
t = t.view(torch.uint8)
handle = reduce_tensor(t)
cur_handles.append((name, handle, orig_dtype_str))

await flush()

Expand Down
45 changes: 44 additions & 1 deletion verl/workers/rollout/trtllm_rollout/trtllm_worker_extension.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
import inspect
from typing import Optional

import torch
from verl.workers.rollout.trtllm_rollout import _w4a4_compat # noqa: F401 (registers float8 storage stubs at import)

# Defer tensorrt_llm imports to avoid FlashInfer's check_cuda_arch() crash
# when this module is loaded on CPU-only Ray actors. The module is normally
# loaded only on GPU workers via string path in trtllm_async_server.py, but
Expand All @@ -37,6 +40,36 @@
logger = None


# Monkey-patch tensorrt_llm.serialization.loads to permit torch.storage._load_from_bytes
# and float8 storage stubs. The parent rlhf_utils.WorkerExtension.update_weights calls
# serialization.loads with a hard-coded approved_imports that omits these — needed for
# W4A4 weight_scale (float8_e4m3fn) IPC round-trip.
#
# NOTE(review): allow-listing torch.storage._load_from_bytes likely widens the
# deserialization surface considerably (that helper is understood to run a full
# torch.load on the payload) — confirm that IPC handles only ever originate from
# trusted in-cluster peers before relying on this patch.
#
# The `_verl_w4a4_patched` flag makes the patch idempotent: a second import (or a
# reload) does not wrap the already-wrapped function again.
if serialization is not None and not getattr(serialization, "_verl_w4a4_patched", False):
    _orig_loads = serialization.loads

    # Extra module -> names entries merged into every approved_imports mapping.
    # Built once at patch time instead of on every loads() call.
    _VERL_EXTRA_APPROVED_IMPORTS = {
        "torch.storage": ("_load_from_bytes", "UntypedStorage", "TypedStorage", "_TypedStorage"),
        "torch._utils": ("_rebuild_tensor_v2",),
        "torch.multiprocessing.reductions": ("rebuild_cuda_tensor", "rebuild_tensor"),
        "torch": (
            "Tensor",
            "Size",
            "dtype",
            "device",
            "Float8_e4m3fnStorage",
            "Float8_e5m2Storage",
            "float8_e4m3fn",
            "float8_e5m2",
        ),
    }

    def _patched_loads(data, approved_imports=None, **kwargs):
        """Call the original ``serialization.loads`` with a widened allow-list.

        The caller's ``approved_imports`` mapping is never mutated: the mapping
        and each per-module list touched here are copied before the extra names
        are appended. Names already approved by the caller are kept and not
        duplicated. All other keyword arguments are forwarded unchanged.
        """
        merged_imports = {} if approved_imports is None else dict(approved_imports)
        for module_name, extra_names in _VERL_EXTRA_APPROVED_IMPORTS.items():
            names = list(merged_imports.get(module_name, []))
            for extra_name in extra_names:
                if extra_name not in names:
                    names.append(extra_name)
            merged_imports[module_name] = names
        return _orig_loads(data, approved_imports=merged_imports, **kwargs)

    serialization.loads = _patched_loads
    serialization._verl_w4a4_patched = True


class WorkerExtension(TrtllmWorkerExtension):
def __init__(self):
pass
Expand Down Expand Up @@ -107,6 +140,7 @@ def update_weights(self, ipc_handles: Optional[dict] = None):
"device",
"float32",
"float16",
"bfloat16",
"int32",
"int64",
"int16",
Expand Down Expand Up @@ -140,11 +174,20 @@ def update_weights(self, ipc_handles: Optional[dict] = None):
# Data is already in the correct format (backward compatibility)
all_handles = serialized_handles

for param_name, tensor_handle in all_handles:
for entry in all_handles:
# Support both 2-tuple (legacy) and 3-tuple (with orig_dtype_str
# for float8 transport-as-uint8 round-trip).
if len(entry) == 3:
param_name, tensor_handle, orig_dtype_str = entry
else:
param_name, tensor_handle = entry
orig_dtype_str = None
func, args = tensor_handle
list_args = list(args)
list_args[6] = self.device_id
tensor = func(*list_args)
if orig_dtype_str is not None:
tensor = tensor.view(getattr(torch, orig_dtype_str))
weights[param_name] = tensor

logger.info(f"weights key size: {len(weights.keys())}")
Expand Down