From 70583780cbce611bd025f5b0228048c2912b9209 Mon Sep 17 00:00:00 2001 From: Hollow Man Date: Mon, 25 May 2026 23:07:01 -0700 Subject: [PATCH 1/2] [megatron] feat: support DeepSeek V4 GRPO Adds DeepSeek V4 Flash GRPO support with Megatron-Bridge actor/ref workers, vLLM rollout, FP8/MXFP4 weight transfer handling, and checkpoint save/export verification. Signed-off-by: Hollow Man --- examples/grpo_trainer/README.md | 3 +- .../run_deepseek_v4_flash_megatron.sh | 241 ++++++ tests/utils/test_bucketed_weight_transfer.py | 9 + tests/utils/test_vllm_fp8_utils.py | 124 +++ .../config/test_model_config_on_cpu.py | 29 +- verl/utils/vllm/vllm_fp8_utils.py | 742 ++++++++++++++---- verl/workers/config/model.py | 27 + .../engine/megatron/transformer_impl.py | 4 + .../vllm_rollout/bucketed_weight_transfer.py | 19 +- verl/workers/rollout/vllm_rollout/utils.py | 39 +- .../rollout/vllm_rollout/vllm_async_server.py | 16 + 11 files changed, 1085 insertions(+), 168 deletions(-) create mode 100644 examples/grpo_trainer/run_deepseek_v4_flash_megatron.sh create mode 100644 tests/utils/test_vllm_fp8_utils.py diff --git a/examples/grpo_trainer/README.md b/examples/grpo_trainer/README.md index 7ad8f4c2436..c09f46723b3 100644 --- a/examples/grpo_trainer/README.md +++ b/examples/grpo_trainer/README.md @@ -48,7 +48,7 @@ run__[_].sh Where: - `` is the canonical size for a model family (`qwen3_8b` for dense text, `qwen3_30b_a3b` for MoE, `qwen2_5_vl_7b` / `qwen3_vl_8b` for vision, - `qwen3_235b_a22b` / `deepseek_v3_671b` for scale demos). + `qwen3_235b_a22b` / `deepseek_v3_671b` / `deepseek_v4_flash` for scale demos). - `` ∈ {`fsdp`, `megatron`, `mindspeed`}. - `` is used only for hardware-specific variants such as `gb200`, `fp8`, `veomni`, or MindSpeed NPU scripts. @@ -88,6 +88,7 @@ bash examples/grpo_trainer/run_qwen3_8b_fsdp.sh | Qwen3.5-35B-A3B (MoE) | | ✓ | | VeOmni | nvidia | | Qwen3.5-122B-A10B | ✓ | | | Megatron | nvidia | | DeepSeek-V3 671B | ✓ | | | Megatron | nvidia | +| DeepSeek-V4-Flash | ✓ | | | Megatron | nvidia | | GLM-4.1V-9B | ✓ | | | FSDP | nvidia | | MiniCPM-o-2.6 | ✓ | | | FSDP | nvidia | | Moonlight-16B-A3B | ✓ | | | Megatron | nvidia | diff --git a/examples/grpo_trainer/run_deepseek_v4_flash_megatron.sh b/examples/grpo_trainer/run_deepseek_v4_flash_megatron.sh new file mode 100644 index 00000000000..050c2a8aa4a --- /dev/null +++ b/examples/grpo_trainer/run_deepseek_v4_flash_megatron.sh @@ -0,0 +1,241 @@ +#!/usr/bin/env bash +# GRPO | DeepSeek-V4-Flash | vLLM rollout | Megatron training | NVIDIA GPUs +# +# Prerequisites on every node: +# CUDA_DEVICE_MAX_CONNECTIONS=1 +# NCCL_NVLS_ENABLE=0 +# VLLM_USE_V1=1 +# Megatron-Bridge/mbridge installed, or set MBRIDGE_PATH=/path/to/Megatron-Bridge. +# Minimum 8 nodes x 8x 80GB+ GPUs recommended for the default smoke settings. + +set -xeuo pipefail + +export CUDA_DEVICE_MAX_CONNECTIONS=1 +export NCCL_NVLS_ENABLE=0 +export TORCH_NCCL_AVOID_RECORD_STREAMS=1 +export VLLM_USE_V1=1 +export VLLM_DISABLE_COMPILE_CACHE=${VLLM_DISABLE_COMPILE_CACHE:-1} +export HYDRA_FULL_ERROR=1 +export PYTHONUNBUFFERED=1 +export PYTORCH_CUDA_ALLOC_CONF=${PYTORCH_CUDA_ALLOC_CONF:-expandable_segments:True} + + +############################### configs ################################ + +MODEL_PATH=${MODEL_PATH:-deepseek-ai/DeepSeek-V4-Flash} +NNODES=${NNODES:-8} +NGPUS_PER_NODE=${NGPUS_PER_NODE:-8} + +TRAIN_BATCH_SIZE=${TRAIN_BATCH_SIZE:-8} +PPO_MINI_BATCH_SIZE=${PPO_MINI_BATCH_SIZE:-4} +ACTOR_PPO_MICRO_BATCH_SIZE_PER_GPU=${ACTOR_PPO_MICRO_BATCH_SIZE_PER_GPU:-1} +DATALOADER_NUM_WORKERS=${DATALOADER_NUM_WORKERS:-8} +MAX_PROMPT_LENGTH=${MAX_PROMPT_LENGTH:-512} +MAX_RESPONSE_LENGTH=${MAX_RESPONSE_LENGTH:-128} +PPO_MAX_TOKEN_LEN_PER_GPU=${PPO_MAX_TOKEN_LEN_PER_GPU:-$((MAX_PROMPT_LENGTH + MAX_RESPONSE_LENGTH))} + +ACTOR_LR=${ACTOR_LR:-1e-6} +OPTIMIZER_OFFLOAD_FRACTION=${OPTIMIZER_OFFLOAD_FRACTION:-1.0} + +ACTOR_TP=${ACTOR_TP:-1} +ACTOR_PP=${ACTOR_PP:-8} +ACTOR_VPP=${ACTOR_VPP:-null} +ACTOR_EP=${ACTOR_EP:-8} +ACTOR_ETP=${ACTOR_ETP:-1} +ACTOR_CP=${ACTOR_CP:-1} +PIPELINE_MODEL_PARALLEL_LAYOUT=${PIPELINE_MODEL_PARALLEL_LAYOUT:-"Et*6|t*6|t*6|t*5|t*5|t*5|t*5|t*5L"} + +REF_TP=${REF_TP:-${ACTOR_TP}} +REF_PP=${REF_PP:-${ACTOR_PP}} +REF_VPP=${REF_VPP:-${ACTOR_VPP}} +REF_EP=${REF_EP:-${ACTOR_EP}} +REF_ETP=${REF_ETP:-${ACTOR_ETP}} +REF_CP=${REF_CP:-${ACTOR_CP}} + +ROLLOUT_TP=${ROLLOUT_TP:-8} +ROLLOUT_N=${ROLLOUT_N:-2} +ROLLOUT_GPU_MEM_UTIL=${ROLLOUT_GPU_MEM_UTIL:-0.275} +ROLLOUT_MAX_MODEL_LEN=${ROLLOUT_MAX_MODEL_LEN:-${PPO_MAX_TOKEN_LEN_PER_GPU}} +ROLLOUT_MAX_NUM_BATCHED_TOKENS=${ROLLOUT_MAX_NUM_BATCHED_TOKENS:-${PPO_MAX_TOKEN_LEN_PER_GPU}} +ROLLOUT_MAX_NUM_SEQS=${ROLLOUT_MAX_NUM_SEQS:-8} +ROLLOUT_KV_CACHE_DTYPE=${ROLLOUT_KV_CACHE_DTYPE:-fp8} +ROLLOUT_UPDATE_WEIGHTS_BUCKET_MB=${ROLLOUT_UPDATE_WEIGHTS_BUCKET_MB:-512} + +ALL_OFFLOAD=${ALL_OFFLOAD:-True} +TOTAL_EPOCHS=${TOTAL_EPOCHS:-1} +TOTAL_TRAINING_STEPS=${TOTAL_TRAINING_STEPS:-1} +SAVE_FREQ=${SAVE_FREQ:--1} +TEST_FREQ=${TEST_FREQ:--1} + +PROJECT_NAME=${PROJECT_NAME:-verl_grpo} +EXPERIMENT_NAME=${EXPERIMENT_NAME:-deepseek_v4_flash_grpo_vllm_megatron} +CKPTS_DIR=${CKPTS_DIR:-"${HOME}/verl/ckpts/${PROJECT_NAME}/${EXPERIMENT_NAME}"} + +TRAIN_FILE=${TRAIN_FILE:-$HOME/data/dapo-math-17k.parquet} +TEST_FILE=${TEST_FILE:-$HOME/data/aime-2024.parquet} +CHAT_TEMPLATE_FILE=${CHAT_TEMPLATE_FILE:-"${CKPTS_DIR}/deepseek_v4_flash_chat_template.jinja"} + +OVERLONG_BUFFER_LEN=${OVERLONG_BUFFER_LEN:-${MAX_RESPONSE_LENGTH}} +OVERLONG_BUFFER_ENABLE=${OVERLONG_BUFFER_ENABLE:-False} +OVERLONG_PENALTY_FACTOR=${OVERLONG_PENALTY_FACTOR:-1.0} + +########################### parameter arrays ########################### + +DEFAULT_CHAT_TEMPLATE='{% for message in messages %}{% if message["content"] is string %}{{ message["content"] }}{% else %}{% for content in message["content"] %}{% if content["type"] == "text" %}{{ content["text"] }}{% endif %}{% endfor %}{% endif %}{% if not loop.last %}{{ "\n\n" }}{% endif %}{% endfor %}{% if add_generation_prompt %}{{ "\n" }}{% endif %}' +if [ ! -f "${CHAT_TEMPLATE_FILE}" ]; then + mkdir -p "$(dirname "${CHAT_TEMPLATE_FILE}")" + printf "%s\n" "${DEEPSEEK_V4_FLASH_CHAT_TEMPLATE:-${DEFAULT_CHAT_TEMPLATE}}" >"${CHAT_TEMPLATE_FILE}" +fi + +ALGORITHM=( + algorithm.adv_estimator=grpo + algorithm.use_kl_in_reward=False + algorithm.kl_ctrl.kl_coef=0.0 +) + +DATA=( + data.train_files="$TRAIN_FILE" + data.val_files="$TEST_FILE" + data.train_batch_size=${TRAIN_BATCH_SIZE} + data.prompt_key=prompt + data.return_raw_chat=True + data.max_prompt_length=${MAX_PROMPT_LENGTH} + data.max_response_length=${MAX_RESPONSE_LENGTH} + data.filter_overlong_prompts=False + data.truncation=left + data.dataloader_num_workers=${DATALOADER_NUM_WORKERS} +) + +MODEL=( + actor_rollout_ref.model.path="$MODEL_PATH" + actor_rollout_ref.model.trust_remote_code=True + actor_rollout_ref.model.custom_chat_template="@${CHAT_TEMPLATE_FILE}" + actor_rollout_ref.model.use_fused_kernels=False + actor_rollout_ref.model.use_remove_padding=False + actor_rollout_ref.model.enable_gradient_checkpointing=True +) + +ACTOR=( + actor_rollout_ref.actor.optim.lr=${ACTOR_LR} + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_offload_fraction=${OPTIMIZER_OFFLOAD_FRACTION} + +actor_rollout_ref.actor.optim.override_optimizer_config.overlap_cpu_optimizer_d2h_h2d=True + +actor_rollout_ref.actor.optim.override_optimizer_config.use_precision_aware_optimizer=True + +actor_rollout_ref.actor.optim.override_optimizer_config.optimizer_cpu_offload=True + actor_rollout_ref.actor.ppo_mini_batch_size=${PPO_MINI_BATCH_SIZE} + actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=${ACTOR_PPO_MICRO_BATCH_SIZE_PER_GPU} + actor_rollout_ref.actor.use_dynamic_bsz=True + actor_rollout_ref.actor.ppo_max_token_len_per_gpu=${PPO_MAX_TOKEN_LEN_PER_GPU} + actor_rollout_ref.actor.use_kl_loss=False + actor_rollout_ref.actor.kl_loss_coef=0.0 + actor_rollout_ref.actor.kl_loss_type=low_var_kl + actor_rollout_ref.actor.entropy_coeff=0 + actor_rollout_ref.actor.loss_agg_mode=token-mean + actor_rollout_ref.actor.megatron.tensor_model_parallel_size=${ACTOR_TP} + actor_rollout_ref.actor.megatron.pipeline_model_parallel_size=${ACTOR_PP} + actor_rollout_ref.actor.megatron.virtual_pipeline_model_parallel_size=${ACTOR_VPP} + actor_rollout_ref.actor.megatron.expert_model_parallel_size=${ACTOR_EP} + actor_rollout_ref.actor.megatron.expert_tensor_parallel_size=${ACTOR_ETP} + actor_rollout_ref.actor.megatron.context_parallel_size=${ACTOR_CP} + actor_rollout_ref.actor.megatron.param_offload=${ALL_OFFLOAD} + actor_rollout_ref.actor.megatron.optimizer_offload=${ALL_OFFLOAD} + actor_rollout_ref.actor.megatron.grad_offload=${ALL_OFFLOAD} + actor_rollout_ref.actor.megatron.use_mbridge=True + actor_rollout_ref.actor.megatron.vanilla_mbridge=False + ++actor_rollout_ref.actor.megatron.override_transformer_config.use_fused_mhc=False + ++actor_rollout_ref.actor.megatron.override_transformer_config.apply_rope_fusion=False + ++actor_rollout_ref.actor.megatron.override_transformer_config.dsa_indexer_loss_coeff=0.0 + ++actor_rollout_ref.actor.megatron.override_transformer_config.recompute_method=uniform + ++actor_rollout_ref.actor.megatron.override_transformer_config.recompute_granularity=full + ++actor_rollout_ref.actor.megatron.override_transformer_config.recompute_num_layers=1 + "++actor_rollout_ref.actor.megatron.override_transformer_config.pipeline_model_parallel_layout='${PIPELINE_MODEL_PARALLEL_LAYOUT}'" +) + +ROLLOUT=( + actor_rollout_ref.rollout.name=vllm + actor_rollout_ref.rollout.tensor_model_parallel_size=${ROLLOUT_TP} + actor_rollout_ref.rollout.gpu_memory_utilization=${ROLLOUT_GPU_MEM_UTIL} + actor_rollout_ref.rollout.n=${ROLLOUT_N} + actor_rollout_ref.rollout.calculate_log_probs=True + actor_rollout_ref.rollout.enable_chunked_prefill=True + actor_rollout_ref.rollout.enforce_eager=True + actor_rollout_ref.rollout.log_prob_use_dynamic_bsz=True + actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=${PPO_MAX_TOKEN_LEN_PER_GPU} + actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=1 + actor_rollout_ref.rollout.max_num_seqs=${ROLLOUT_MAX_NUM_SEQS} + actor_rollout_ref.rollout.max_num_batched_tokens=${ROLLOUT_MAX_NUM_BATCHED_TOKENS} + actor_rollout_ref.rollout.max_model_len=${ROLLOUT_MAX_MODEL_LEN} + actor_rollout_ref.rollout.checkpoint_engine.update_weights_bucket_megabytes=${ROLLOUT_UPDATE_WEIGHTS_BUCKET_MB} + +actor_rollout_ref.rollout.engine_kwargs.vllm.kv_cache_dtype=${ROLLOUT_KV_CACHE_DTYPE} + actor_rollout_ref.rollout.prompt_length=${MAX_PROMPT_LENGTH} + actor_rollout_ref.rollout.response_length=${MAX_RESPONSE_LENGTH} + actor_rollout_ref.rollout.temperature=1.0 + actor_rollout_ref.rollout.top_p=1.0 + actor_rollout_ref.rollout.val_kwargs.do_sample=True + actor_rollout_ref.rollout.val_kwargs.temperature=1.0 + actor_rollout_ref.rollout.val_kwargs.top_p=0.7 + actor_rollout_ref.rollout.val_kwargs.n=1 +) + +REF=( + actor_rollout_ref.ref.log_prob_use_dynamic_bsz=True + actor_rollout_ref.ref.log_prob_max_token_len_per_gpu=${PPO_MAX_TOKEN_LEN_PER_GPU} + actor_rollout_ref.ref.log_prob_micro_batch_size_per_gpu=1 + actor_rollout_ref.ref.megatron.tensor_model_parallel_size=${REF_TP} + actor_rollout_ref.ref.megatron.pipeline_model_parallel_size=${REF_PP} + actor_rollout_ref.ref.megatron.virtual_pipeline_model_parallel_size=${REF_VPP} + actor_rollout_ref.ref.megatron.expert_model_parallel_size=${REF_EP} + actor_rollout_ref.ref.megatron.expert_tensor_parallel_size=${REF_ETP} + actor_rollout_ref.ref.megatron.context_parallel_size=${REF_CP} + actor_rollout_ref.ref.megatron.param_offload=${ALL_OFFLOAD} + actor_rollout_ref.ref.megatron.use_mbridge=True + actor_rollout_ref.ref.megatron.vanilla_mbridge=False + ++actor_rollout_ref.ref.megatron.override_transformer_config.use_fused_mhc=False + ++actor_rollout_ref.ref.megatron.override_transformer_config.apply_rope_fusion=False + ++actor_rollout_ref.ref.megatron.override_transformer_config.dsa_indexer_loss_coeff=0.0 + "++actor_rollout_ref.ref.megatron.override_transformer_config.pipeline_model_parallel_layout='${PIPELINE_MODEL_PARALLEL_LAYOUT}'" +) + +REWARD=( + reward.reward_manager.name=dapo + +reward.reward_kwargs.overlong_buffer_cfg.enable=${OVERLONG_BUFFER_ENABLE} + +reward.reward_kwargs.overlong_buffer_cfg.len=${OVERLONG_BUFFER_LEN} + +reward.reward_kwargs.overlong_buffer_cfg.penalty_factor=${OVERLONG_PENALTY_FACTOR} + +reward.reward_kwargs.overlong_buffer_cfg.log=False + +reward.reward_kwargs.max_resp_len=${MAX_RESPONSE_LENGTH} +) + +TRAINER=( + trainer.balance_batch=True + trainer.logger='["console","wandb"]' + trainer.project_name=${PROJECT_NAME} + trainer.experiment_name=${EXPERIMENT_NAME} + trainer.n_gpus_per_node=${NGPUS_PER_NODE} + trainer.nnodes=${NNODES} + trainer.save_freq=${SAVE_FREQ} + trainer.test_freq=${TEST_FREQ} + trainer.total_epochs=${TOTAL_EPOCHS} + trainer.total_training_steps=${TOTAL_TRAINING_STEPS} + trainer.resume_mode=disable + trainer.val_before_train=False + trainer.log_val_generations=0 + trainer.default_local_dir="${CKPTS_DIR}" +) + +EXTRA=( + actor_rollout_ref.nccl_timeout=3600 + model_engine=megatron +) + +########################### launch ########################### + +python3 -m verl.trainer.main_ppo \ + "${ALGORITHM[@]}" \ + "${DATA[@]}" \ + "${MODEL[@]}" \ + "${ACTOR[@]}" \ + "${ROLLOUT[@]}" \ + "${REF[@]}" \ + "${REWARD[@]}" \ + "${TRAINER[@]}" \ + "${EXTRA[@]}" \ + "$@" diff --git a/tests/utils/test_bucketed_weight_transfer.py b/tests/utils/test_bucketed_weight_transfer.py index 7a3fbdb127a..1ba20a0e77d 100644 --- a/tests/utils/test_bucketed_weight_transfer.py +++ b/tests/utils/test_bucketed_weight_transfer.py @@ -58,6 +58,15 @@ def _generate_weights(weight_specs, seed): return weights +def test_align_offset_respects_tensor_element_size(): + from verl.workers.rollout.vllm_rollout.bucketed_weight_transfer import _align_offset + + assert _align_offset(0, 2) == 0 + assert _align_offset(1, 2) == 2 + assert _align_offset(3, 4) == 4 + assert _align_offset(8, 4) == 8 + + # --------------------------------------------------------------------------- # Process entry points (must be module-level for pickling with spawn) # --------------------------------------------------------------------------- diff --git a/tests/utils/test_vllm_fp8_utils.py b/tests/utils/test_vllm_fp8_utils.py new file mode 100644 index 00000000000..7ae619c4af1 --- /dev/null +++ b/tests/utils/test_vllm_fp8_utils.py @@ -0,0 +1,124 @@ +# Copyright 2026 Bytedance Ltd. and/or its affiliates +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest +import torch + + +def _import_quantize_mxfp4_weight(): + pytest.importorskip("vllm") + + try: + from verl.utils.vllm.vllm_fp8_utils import quantize_mxfp4_weight + except ImportError as exc: + pytest.skip(f"vLLM FP8 utilities are unavailable: {exc}") + + return quantize_mxfp4_weight + + +def _import_vllm_fp8_utils(): + pytest.importorskip("vllm") + + try: + from verl.utils.vllm import vllm_fp8_utils + except ImportError as exc: + pytest.skip(f"vLLM FP8 utilities are unavailable: {exc}") + + return vllm_fp8_utils + + +def test_quantize_mxfp4_weight_packs_two_values_per_byte(): + quantize_mxfp4_weight = _import_quantize_mxfp4_weight() + + weight = torch.arange(64, dtype=torch.float32).reshape(2, 32) + + quant_weight, quant_scale = quantize_mxfp4_weight(weight) + + assert quant_weight.dtype == torch.uint8 + assert quant_scale.dtype == torch.uint8 + assert quant_weight.shape == (2, 16) + assert quant_scale.shape == (2, 1) + + +def test_quantize_mxfp4_weight_requires_32_value_blocks(): + quantize_mxfp4_weight = _import_quantize_mxfp4_weight() + + with pytest.raises(ValueError, match="divisible by 32"): + quantize_mxfp4_weight(torch.ones(31)) + + +def test_prequantized_mxfp4_detection_includes_signed_packed_weights(): + vllm_fp8_utils = _import_vllm_fp8_utils() + + assert vllm_fp8_utils._is_prequantized_mxfp4_tensor(torch.empty(1, dtype=torch.int8)) + assert vllm_fp8_utils._is_prequantized_mxfp4_tensor(torch.empty(1, dtype=torch.uint8)) + assert not vllm_fp8_utils._is_prequantized_mxfp4_tensor(torch.empty(1, dtype=torch.bfloat16)) + + +def test_prequantized_fp8_detection_keeps_loaded_fp8_weights_intact(): + vllm_fp8_utils = _import_vllm_fp8_utils() + + assert vllm_fp8_utils._is_prequantized_fp8_tensor(torch.empty(1, dtype=torch.float8_e4m3fn)) + assert not vllm_fp8_utils._is_prequantized_fp8_tensor(torch.empty(1, dtype=torch.bfloat16)) + + +def test_deepseek_v4_scale_name_uses_sibling_scale_suffix(): + vllm_fp8_utils = _import_vllm_fp8_utils() + + class Config: + model_type = "deepseek_v4" + + class Model: + config = Config() + + assert ( + vllm_fp8_utils._scale_name_for_weight( + "layers.0.ffn.experts.0.w1.weight", + Model(), + use_scale_not_scale_inv=True, + ) + == "layers.0.ffn.experts.0.w1.scale" + ) + + +def test_forced_scale_name_keeps_mxfp4_suffix_for_other_models(): + vllm_fp8_utils = _import_vllm_fp8_utils() + + class Model: + pass + + assert ( + vllm_fp8_utils._scale_name_for_weight( + "model.layers.0.mlp.experts.0.gate_proj.weight", + Model(), + force_scale=True, + ) + == "model.layers.0.mlp.experts.0.gate_proj.weight_scale" + ) + + +def test_default_scale_name_keeps_vllm_fp8_suffix_convention(): + vllm_fp8_utils = _import_vllm_fp8_utils() + + class Model: + pass + + assert ( + vllm_fp8_utils._scale_name_for_weight( + "model.layers.0.mlp.experts.0.gate_proj.weight", + Model(), + use_scale_not_scale_inv=True, + ) + == "model.layers.0.mlp.experts.0.gate_proj.weight_scale_inv" + ) diff --git a/tests/workers/config/test_model_config_on_cpu.py b/tests/workers/config/test_model_config_on_cpu.py index e76985278ac..91f516efbc8 100644 --- a/tests/workers/config/test_model_config_on_cpu.py +++ b/tests/workers/config/test_model_config_on_cpu.py @@ -17,7 +17,7 @@ import pytest from omegaconf import OmegaConf -from verl.workers.config.model import HFModelConfig +from verl.workers.config.model import HFModelConfig, _resolve_custom_chat_template class TestHFModelConfigCPU: @@ -94,3 +94,30 @@ def test_target_modules_raises_on_invalid_type(self): merged_config = OmegaConf.merge(base_config, invalid_cli_config) with pytest.raises(TypeError): OmegaConf.to_object(merged_config) + + def test_resolve_custom_chat_template_from_file(self, tmp_path): + template_path = tmp_path / "template.jinja" + template_path.write_text("{{ messages[0]['content'] }}", encoding="utf-8") + + assert _resolve_custom_chat_template(f"@{template_path}") == "{{ messages[0]['content'] }}" + + def test_resolve_custom_chat_template_from_env(self, monkeypatch): + monkeypatch.setenv("VERL_TEST_CHAT_TEMPLATE", "{{ bos_token }}") + + assert _resolve_custom_chat_template("env:VERL_TEST_CHAT_TEMPLATE") == "{{ bos_token }}" + assert _resolve_custom_chat_template("${oc.env:VERL_TEST_CHAT_TEMPLATE}") == "{{ bos_token }}" + + def test_custom_chat_template_is_mutable_for_omegaconf_override(self): + assert "custom_chat_template" in HFModelConfig._mutable_fields + + cfg_from_dataclass = OmegaConf.structured(HFModelConfig) + cli_config = OmegaConf.create( + { + "path": self.model_path, + "custom_chat_template": "env:VERL_TEST_CHAT_TEMPLATE", + } + ) + + merged = OmegaConf.merge(cfg_from_dataclass, cli_config) + + assert merged.custom_chat_template == "env:VERL_TEST_CHAT_TEMPLATE" diff --git a/verl/utils/vllm/vllm_fp8_utils.py b/verl/utils/vllm/vllm_fp8_utils.py index 2477295aa3a..a0a673418a6 100644 --- a/verl/utils/vllm/vllm_fp8_utils.py +++ b/verl/utils/vllm/vllm_fp8_utils.py @@ -16,6 +16,7 @@ import inspect import logging from dataclasses import dataclass, field +from types import MethodType from unittest.mock import patch import torch @@ -28,10 +29,17 @@ except ImportError as e: raise ImportError("FP8 quantization not available") from e +from verl.utils.device import get_device_name from verl.utils.kernel.fp8_kernel import scaled_fp8_blockwise logger = logging.getLogger(__name__) +_MXFP4_E2M1_MAX = 6.0 +_MXFP4_E8M0_BIAS = 127 +_MXFP4_E8M0_MIN = 1 +_MXFP4_E8M0_MAX = 254 +_MXFP4_E2M1_THRESHOLDS = (0.25, 0.75, 1.25, 1.75, 2.5, 3.5, 5.0) + # Ref: https://github.com/NVIDIA-NeMo/RL/commit/bc24887c72a6e1b2699a228bc87c588546dfe6b7 @dataclass() @@ -46,46 +54,411 @@ class FP8State: fp8_state: FP8State = FP8State() -def is_fp8_model(vllm_config): - from vllm.model_executor.layers.quantization.fp8 import Fp8Config +def _copy_param_attributes(dst_param, src_param): + base_param_dir = dir(torch.nn.Parameter) + for attr in dir(src_param): + if attr not in base_param_dir and not attr.startswith("__"): + setattr(dst_param, attr, getattr(src_param, attr)) + + +def _create_param_from_subclass_attributes(custom_param, source_param=None): + param = torch.nn.Parameter(custom_param.data, requires_grad=False) + if source_param is not None: + _copy_param_attributes(param, source_param) + _copy_param_attributes(param, custom_param) + + param.subclass_type = type(custom_param) + return param + + +def _create_param_from_data_with_attrs(data, source_param): + param = torch.nn.Parameter(data, requires_grad=False) + _copy_param_attributes(param, source_param) + return param + + +def _get_param_weight_loader(param): + weight_loader = getattr(param, "weight_loader", None) + if weight_loader is None: + weight_loader = getattr(param, "_weight_loader", None) + return weight_loader + + +def _param_parallel_dim(param, public_name, private_name, default): + if hasattr(param, private_name): + return getattr(param, private_name) + if hasattr(param, public_name): + return getattr(param, public_name) + return default + + +def _copy_loaded_weight(param, loaded_weight): + param.data.copy_(loaded_weight.to(device=param.data.device, dtype=param.data.dtype)) + + +def _copy_weight_attrs(param, weight_loader=None, quant_method=None): + if weight_loader is not None: + param.weight_loader = weight_loader + if quant_method is not None: + param.quant_method = quant_method + return param + + +def _is_rank_zero(): + return ( + not torch.distributed.is_available() + or not torch.distributed.is_initialized() + or torch.distributed.get_rank() == 0 + ) + - if hasattr(vllm_config, "quant_config"): - if isinstance(vllm_config.quant_config, Fp8Config): +def _normalize_dim(dim, ndim): + if dim is None: + return 0 + if dim < 0: + dim += ndim + return dim + + +def _try_load_grouped_column_weight(param, loaded_weight): + tp_size = int(getattr(param, "tp_size", 1)) + tp_rank = int(getattr(param, "tp_rank", 0)) + if tp_size <= 1: + return False + + data = param.data + if loaded_weight.ndim == data.ndim + 1 and loaded_weight.shape == torch.Size((tp_size, *data.shape)): + _copy_loaded_weight(param, loaded_weight.select(0, tp_rank)) + return True + + if data.ndim >= 2 and loaded_weight.ndim == data.ndim - 1: + expected_flat_shape = (tp_size * data.shape[0] * data.shape[1], *data.shape[2:]) + if loaded_weight.shape == torch.Size(expected_flat_shape): + global_shape = (tp_size, data.shape[0], data.shape[1], *data.shape[2:]) + _copy_loaded_weight(param, loaded_weight.reshape(global_shape).select(0, tp_rank)) return True - elif is_mxfp8_vllm_ascend(vllm_config.quant_config): + + if data.ndim == loaded_weight.ndim and data.ndim > 0: + expected_dim0 = tp_size * data.shape[0] + if loaded_weight.shape[0] == expected_dim0 and loaded_weight.shape[1:] == data.shape[1:]: + start = tp_rank * data.shape[0] + _copy_loaded_weight(param, loaded_weight.narrow(0, start, data.shape[0])) return True return False +def _merged_column_offsets(param, loaded_weight, shard_offset, shard_size, loaded_shard_id): + dim = _normalize_dim(getattr(param, "output_dim", getattr(param, "_output_dim", 0)), param.data.ndim) + loaded_dim = loaded_weight.shape[dim] + offsets = [] + if isinstance(shard_offset, int): + offsets.append(shard_offset) + if isinstance(shard_size, int) and shard_size > 0: + scaled_offset = shard_offset * loaded_dim // shard_size + offsets.append(scaled_offset) + if isinstance(loaded_shard_id, int): + offsets.append(loaded_shard_id * loaded_dim) + + seen = set() + for offset in offsets: + if offset in seen: + continue + seen.add(offset) + if offset < 0 or offset + loaded_dim > param.data.shape[dim]: + continue + yield dim, offset + + +def _try_load_merged_column_weight(param, loaded_weight, shard_offset, shard_size, loaded_shard_id): + for dim, offset in _merged_column_offsets(param, loaded_weight, shard_offset, shard_size, loaded_shard_id): + target = param.data.narrow(dim, offset, loaded_weight.shape[dim]) + if target.shape == loaded_weight.shape: + target.copy_(loaded_weight.to(device=target.device, dtype=target.dtype)) + return True + return False + + +def _attach_fp8_reload_fallbacks(param): + subclass_type = getattr(param, "subclass_type", None) + if subclass_type is None or getattr(param, "_verl_fp8_reload_fallbacks", False): + return + + original_column_loader = getattr(subclass_type, "load_column_parallel_weight", None) + original_merged_loader = getattr(subclass_type, "load_merged_column_weight", None) + + if original_column_loader is not None: + + def load_column_parallel_weight(self, *args, **kwargs): + loaded_weight = kwargs.get("loaded_weight") + if loaded_weight is None and args: + loaded_weight = args[0] + if loaded_weight is not None: + if self.data.shape == loaded_weight.shape: + _copy_loaded_weight(self, loaded_weight) + return + if _try_load_grouped_column_weight(self, loaded_weight): + return + return original_column_loader(self, *args, **kwargs) + + param.load_column_parallel_weight = MethodType(load_column_parallel_weight, param) + + if original_merged_loader is not None: + + def load_merged_column_weight(self, *args, **kwargs): + loaded_weight = kwargs.get("loaded_weight") + if loaded_weight is None and args: + loaded_weight = args[0] + loaded_shard_id = kwargs.get("loaded_shard_id", kwargs.get("shard_id")) + if loaded_shard_id is None and len(args) > 1: + loaded_shard_id = args[1] + shard_offset = kwargs.get("shard_offset") + if shard_offset is None and len(args) > 2: + shard_offset = args[2] + shard_size = kwargs.get("shard_size") + if shard_size is None and len(args) > 3: + shard_size = args[3] + + if loaded_weight is not None and _try_load_merged_column_weight( + self, + loaded_weight, + shard_offset, + shard_size, + loaded_shard_id, + ): + return + return original_merged_loader(self, *args, **kwargs) + + param.load_merged_column_weight = MethodType(load_merged_column_weight, param) + + param._verl_fp8_reload_fallbacks = True + + +def _ensure_linear_params_reloadable(layer): + """Restore vLLM parameter metadata after fp8 post-processing replaces it.""" + from vllm.model_executor.parameter import ( + BlockQuantScaleParameter, + ModelWeightParameter, + ) + + if hasattr(layer, "weight") and not hasattr(layer.weight, "subclass_type"): + source_weight = layer.weight + weight_loader = _get_param_weight_loader(layer.weight) + if weight_loader is not None: + layer.weight = _create_param_from_subclass_attributes( + ModelWeightParameter( + data=source_weight.data, + output_dim=_param_parallel_dim(source_weight, "output_dim", "_output_dim", 0), + input_dim=_param_parallel_dim(source_weight, "input_dim", "_input_dim", 1), + weight_loader=weight_loader, + ), + source_weight, + ) + + for scale_name in ("weight_scale_inv", "weight_scale"): + if not hasattr(layer, scale_name): + continue + scale = getattr(layer, scale_name) + if hasattr(scale, "subclass_type"): + continue + weight_loader = _get_param_weight_loader(scale) + if weight_loader is not None: + setattr( + layer, + scale_name, + _create_param_from_subclass_attributes( + BlockQuantScaleParameter( + data=scale.data, + output_dim=_param_parallel_dim(scale, "output_dim", "_output_dim", 0), + input_dim=_param_parallel_dim(scale, "input_dim", "_input_dim", 1), + weight_loader=weight_loader, + ), + scale, + ), + ) + + update_param_tp_status = getattr(layer, "update_param_tp_status", None) + if callable(update_param_tp_status): + update_param_tp_status() + + for param_name in ("weight", "weight_scale_inv", "weight_scale"): + param = getattr(layer, param_name, None) + if param is not None: + _attach_fp8_reload_fallbacks(param) + + +def _ensure_model_params_reloadable(model): + for module in model.modules(): + if isinstance(module, LinearBase): + _ensure_linear_params_reloadable(module) + + +def _is_mxfp4_moe_module(module): + if not isinstance(module, FusedMoE): + return False + quant_method = getattr(module, "quant_method", None) + quant_method_name = type(quant_method).__name__ + return quant_method_name in ("Mxfp4MoEMethod", "GptOssMxfp4MoEMethod") or getattr( + quant_method, "weight_dtype", None + ) in ("mxfp4", "gpt_oss_mxfp4") + + +def _mxfp4_moe_load_shape(module): + quant_method = getattr(module, "quant_method", None) + shape = ( + getattr(quant_method, "num_experts", getattr(module, "local_num_experts", None)), + getattr(quant_method, "intermediate_size", getattr(module, "intermediate_size_per_partition", None)), + getattr(quant_method, "hidden_size", getattr(module, "hidden_size", None)), + ) + if any(value is None for value in shape): + return None + return tuple(int(value) for value in shape) + + +def _module_param_device(module): + for param_name in ("w13_weight", "w2_weight", "w13_weight_scale", "w2_weight_scale"): + param = getattr(module, param_name, None) + if param is not None and hasattr(param, "device"): + return param.device + return torch.device(get_device_name()) + + +def _make_mxfp4_moe_param(shape, device, weight_loader, quant_method=None): + param = torch.nn.Parameter(torch.empty(shape, dtype=torch.uint8, device=device), requires_grad=False) + return _copy_weight_attrs(param, weight_loader=weight_loader, quant_method=quant_method) + + +def _restore_mxfp4_moe_params_for_loading(model): + restored = False + for module in model.modules(): + if not _is_mxfp4_moe_module(module): + continue + load_shape = _mxfp4_moe_load_shape(module) + if load_shape is None: + continue + + num_experts, intermediate_size, hidden_size = load_shape + device = _module_param_device(module) + weight_loader = getattr(module, "weight_loader", None) + + module.w13_weight = _make_mxfp4_moe_param( + (num_experts, 2 * intermediate_size, hidden_size // 2), + device, + weight_loader, + ) + module.w2_weight = _make_mxfp4_moe_param( + (num_experts, hidden_size, intermediate_size // 2), + device, + weight_loader, + ) + module.w13_weight_scale = _make_mxfp4_moe_param( + (num_experts, 2 * intermediate_size, hidden_size // 32), + device, + weight_loader, + quant_method="block", + ) + module.w2_weight_scale = _make_mxfp4_moe_param( + (num_experts, hidden_size, intermediate_size // 32), + device, + weight_loader, + quant_method="block", + ) + + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + for attr in ("moe_kernel", "moe_quant_config", "w13_precision_config", "w2_precision_config"): + if hasattr(quant_method, attr): + setattr(quant_method, attr, None) + restored = True + return restored + + +def _process_mxfp4_moe_weights_after_loading(model): + for module in model.modules(): + if not _is_mxfp4_moe_module(module): + continue + quant_method = getattr(module, "quant_method", None) + process_weights = getattr(quant_method, "process_weights_after_loading", None) + if callable(process_weights): + process_weights(module) + + +def is_fp8_model(vllm_config): + from vllm.model_executor.layers.quantization.fp8 import Fp8Config + + quant_config = getattr(vllm_config, "quant_config", None) + return isinstance(quant_config, Fp8Config) or is_mxfp8_vllm_ascend(quant_config) + + def get_module_from_param_name(model, name: str): - # Split the name into parts (e.g., 'layers', '0', 'self_attn', 'q_proj', 'weight') - # The module path is all but the last part (the parameter's own name) - path_parts = name.split(".") - module_path = path_parts[:-1] - # Replace with the fused model name - packed_modules_mapping = model.packed_modules_mapping + if name.startswith("mtp."): + return None + + def _mapped_name(param_name: str) -> str: + mapper = getattr(model, "hf_to_vllm_mapper", None) + map_name = getattr(mapper, "_map_name", None) + if callable(map_name): + mapped = map_name(param_name) + if mapped is not None: + return mapped + return param_name + + def _candidate_module_paths(param_name: str) -> list[list[str]]: + # Split the name into parts (e.g., 'layers', '0', 'self_attn', 'q_proj', 'weight') + # The module path is all but the last part (the parameter's own name) + path_parts = _mapped_name(param_name).split(".") + if not path_parts: + return [] + module_path = path_parts[:-1] + + # Replace with the fused model name + packed_module_path = list(module_path) + if packed_module_path and packed_module_path[-1] in reversed_mapping.keys(): + packed_module_path[-1] = reversed_mapping[packed_module_path[-1]] + + candidates = [packed_module_path] + + # DeepSeek-V4 keeps stacked checkpoint names in its load_weights method + # rather than in packed_modules_mapping, so mirror those aliases here. + module_path_str = ".".join(packed_module_path) + deepseek_v4_aliases = { + ".ffn.shared_experts.w1": ".ffn.shared_experts.gate_up_proj", + ".ffn.shared_experts.w3": ".ffn.shared_experts.gate_up_proj", + ".attn.wq_a": ".attn.fused_wqa_wkv", + ".attn.wkv": ".attn.fused_wqa_wkv", + ".compressor.wkv": ".compressor.fused_wkv_wgate", + ".compressor.wgate": ".compressor.fused_wkv_wgate", + } + for old, new in deepseek_v4_aliases.items(): + if old in module_path_str: + candidates.append(module_path_str.replace(old, new, 1).split(".")) + + return candidates + + packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) reversed_mapping = { original_name: fused_name for fused_name, original_names_list in packed_modules_mapping.items() for original_name in original_names_list } - if module_path[-1] in reversed_mapping.keys(): - module_path[-1] = reversed_mapping[module_path[-1]] - current_module = model - try: - # Traverse the model hierarchy - for part in module_path: - if isinstance(current_module, FusedMoE): - return current_module - elif isinstance(current_module, torch.nn.ModuleList): - current_module = current_module[int(part)] - else: - current_module = getattr(current_module, part) - except (AttributeError, IndexError, ValueError) as e: - print(f"Warning: Could not find module for parameter '{name}'. Error: {e}") - return current_module + last_error = None + for module_path in _candidate_module_paths(name): + current_module = model + try: + # Traverse the model hierarchy + for part in module_path: + if isinstance(current_module, FusedMoE): + return current_module + elif isinstance(current_module, torch.nn.ModuleList): + current_module = current_module[int(part)] + else: + current_module = getattr(current_module, part) + return current_module + except (AttributeError, IndexError, ValueError) as e: + last_error = e + logger.debug("Could not find module for parameter %r: %s", name, last_error) + return None def is_fp8_weight(name, model): @@ -96,21 +469,109 @@ def is_fp8_weight(name, model): module = get_module_from_param_name(model, name) # We currently only quantize linear layers - if (isinstance(module, LinearBase) and module.weight.dtype == torch.float8_e4m3fn) or ( - isinstance(module, FusedMoE) - and module.w13_weight.dtype == torch.float8_e4m3fn - and module.w2_weight.dtype == torch.float8_e4m3fn + if module is not None and ( + (isinstance(module, LinearBase) and module.weight.dtype == torch.float8_e4m3fn) + or ( + isinstance(module, FusedMoE) + and module.w13_weight.dtype == torch.float8_e4m3fn + and module.w2_weight.dtype == torch.float8_e4m3fn + ) ): fp8_state.fp8_param_names.add(name) return name in fp8_state.fp8_param_names +def is_mxfp4_moe_weight(name, tensor, model): + if not name.endswith(".weight") or ".experts." not in name: + return False + if _is_prequantized_mxfp4_tensor(tensor): + return False + module = get_module_from_param_name(model, name) + return _is_mxfp4_moe_module(module) + + +def _is_prequantized_mxfp4_tensor(tensor): + float4_dtype = getattr(torch, "float4_e2m1fn_x2", None) + return tensor.dtype in (torch.int8, torch.uint8, float4_dtype) + + +def _is_prequantized_fp8_tensor(tensor): + fp8_dtypes = tuple( + dtype + for dtype in (getattr(torch, "float8_e4m3fn", None), getattr(torch, "float8_e5m2", None)) + if dtype is not None + ) + return tensor.dtype in fp8_dtypes + + +def _model_type(model): + for obj in (model, getattr(model, "config", None), getattr(model, "hf_config", None)): + if obj is None: + continue + model_type = getattr(obj, "model_type", None) + if model_type is not None: + return model_type + config = getattr(model, "config", None) + text_config = getattr(config, "text_config", None) + return getattr(text_config, "model_type", None) + + +def _uses_dot_scale_suffix(model): + return _model_type(model) == "deepseek_v4" + + +def _scale_name_for_weight(name, model, *, is_mxfp8_npu=False, use_scale_not_scale_inv=False, force_scale=False): + if name.endswith(".weight") and _uses_dot_scale_suffix(model): + return name[: -len(".weight")] + ".scale" + if force_scale: + return name + "_scale" + if is_mxfp8_npu: + return name + "_scale" + if use_scale_not_scale_inv and "expert" not in name: + return name + "_scale" + return name + "_scale_inv" + + +def _mxfp4_scale_to_e8m0(scale): + scale = scale.to(torch.float32) + safe_scale = torch.where(scale > 0, scale, torch.ones_like(scale)) + scale_exp = torch.ceil(torch.log2(safe_scale)).to(torch.int32) + _MXFP4_E8M0_BIAS + return scale_exp.clamp(_MXFP4_E8M0_MIN, _MXFP4_E8M0_MAX).to(torch.uint8) + + +def quantize_mxfp4_weight(weight, dtype=torch.bfloat16): + target_dtype = dtype if dtype in (torch.bfloat16, torch.float16) else torch.bfloat16 + weight = weight.to(target_dtype).contiguous() + *prefix_shape, hidden_dim = weight.shape + if hidden_dim % 32 != 0: + raise ValueError(f"MXFP4 weight hidden dimension must be divisible by 32, got shape={tuple(weight.shape)}") + + block_size = 32 + num_blocks = hidden_dim // block_size + blocks = weight.reshape(-1, num_blocks, block_size).to(torch.float32) + + amax = blocks.abs().amax(dim=-1) + quant_scale = _mxfp4_scale_to_e8m0(amax / _MXFP4_E2M1_MAX) + scale = torch.exp2((quant_scale.to(torch.float32) - _MXFP4_E8M0_BIAS).unsqueeze(-1)) + + scaled = (blocks / scale).clamp(-_MXFP4_E2M1_MAX, _MXFP4_E2M1_MAX) + thresholds = torch.tensor(_MXFP4_E2M1_THRESHOLDS, dtype=torch.float32, device=weight.device) + magnitude = torch.bucketize(scaled.abs(), thresholds).to(torch.uint8) + sign = torch.where(scaled < 0, torch.full_like(magnitude, 8), torch.zeros_like(magnitude)) + codes = magnitude | sign + + quant_weight = codes[..., 0::2] | (codes[..., 1::2] * 16) + quant_weight = quant_weight.view(*prefix_shape, hidden_dim // 2) + quant_scale = quant_scale.view(*prefix_shape, num_blocks) + return quant_weight, quant_scale + + def is_mxfp8_vllm_ascend(quant_config): try: from vllm_ascend.quantization.modelslim_config import AscendModelSlimConfig from vllm_ascend.quantization.quant_config import AscendQuantConfig - if isinstance(quant_config, AscendModelSlimConfig) or isinstance(quant_config, AscendQuantConfig): + if isinstance(quant_config, AscendModelSlimConfig | AscendQuantConfig): quant_method = quant_config.quant_description.get("quant_method") return quant_method in ["ascend"] return False @@ -139,12 +600,6 @@ def apply_mxfp8_transformation_after_loading(model): Must be called AFTER model.load_weights() in RL training loops. """ - try: - from vllm.model_executor.layers.linear import LinearBase - except ImportError: - logger.warning("Could not import LinearBase, skipping MXFP8 transformation") - return - for name, module in model.named_modules(): if (isinstance(module, LinearBase) or isinstance(module, FusedMoE)) and hasattr( module, "_mxfp8_original_shapes" @@ -178,12 +633,25 @@ def quant_weights(weights, model, quant_config, dtype=torch.bfloat16): _use_scale_not_scale_inv = version.parse("0.11.0") <= version.parse(vllm.__version__) < version.parse("0.14.0") for k, v in weights: + if is_mxfp4_moe_weight(k, v, model): + if _is_rank_zero(): + logger.debug(f"Quantizing to MXFP4 blockwise: {k}") + param_lp, param_scale = quantize_mxfp4_weight(v, dtype=dtype) + yield (k, param_lp) + yield (_scale_name_for_weight(k, model, force_scale=True), param_scale) + del v, param_lp, param_scale + continue + if not is_fp8_weight(k, model): yield (k, v) continue + if _is_prequantized_fp8_tensor(v): + yield (k, v) + continue + # Cast the weight into fp8 and its scale factor - if torch.distributed.get_rank() == 0: + if _is_rank_zero(): logger.debug(f"Quantizing to FP8 blockwise: {k}") if is_mxfp8_npu: param_lp, param_scale = torch_npu.npu_dynamic_mx_quant( @@ -203,18 +671,21 @@ def quant_weights(weights, model, quant_config, dtype=torch.bfloat16): yield (k, param_lp) # Yield the scale with appropriate naming based on vLLM version - if is_mxfp8_npu: - yield (k + "_scale", param_scale) - elif _use_scale_not_scale_inv and "expert" not in k: - yield (k + "_scale", param_scale) - else: - yield (k + "_scale_inv", param_scale) + yield ( + _scale_name_for_weight( + k, + model, + is_mxfp8_npu=is_mxfp8_npu, + use_scale_not_scale_inv=_use_scale_not_scale_inv, + ), + param_scale, + ) # Explicitly delete original tensor reference to help GC del v, param_lp, param_scale -def load_quanted_weights(weights, model_runner, is_drafter=False): +def _get_quanted_weight_model(model_runner, is_drafter=False): if is_drafter: drafter = getattr(model_runner, "drafter", None) model = drafter.model if drafter is not None and hasattr(drafter, "model") else None @@ -224,10 +695,15 @@ def load_quanted_weights(weights, model_runner, is_drafter=False): ) else: model = model_runner.model + return model + + +def prepare_quanted_weights_for_loading(model_runner, is_drafter=False): + model = _get_quanted_weight_model(model_runner, is_drafter=is_drafter) quant_config = model_runner.vllm_config.quant_config - vllm_dtype = model_runner.vllm_config.model_config.dtype is_mxfp8_npu = is_mxfp8_vllm_ascend(quant_config) + is_mxfp4_moe = _restore_mxfp4_moe_params_for_loading(model) if is_mxfp8_npu: # For MXFP8 on NPU, we need to restore weights to original shapes @@ -236,24 +712,54 @@ def load_quanted_weights(weights, model_runner, is_drafter=False): # but the weight_loader expects original shapes. restore_mxfp8_weights_for_loading(model) + _ensure_model_params_reloadable(model) + return {"is_mxfp8_npu": is_mxfp8_npu, "is_mxfp4_moe": is_mxfp4_moe} + + +def process_quanted_weights_after_loading(model_runner, reload_state=None, is_drafter=False): + model = _get_quanted_weight_model(model_runner, is_drafter=is_drafter) + quant_config = model_runner.vllm_config.quant_config + if reload_state is None: + reload_state = { + "is_mxfp8_npu": is_mxfp8_vllm_ascend(quant_config), + "is_mxfp4_moe": any(_is_mxfp4_moe_module(module) for module in model.modules()), + } + if reload_state.get("is_mxfp8_npu"): + # Re-apply MXFP8 transformations after weight loading. + apply_mxfp8_transformation_after_loading(model) + if reload_state.get("is_mxfp4_moe"): + _process_mxfp4_moe_weights_after_loading(model) + + +def load_quanted_weights(weights, model_runner, is_drafter=False, prepare_model=True, process_model=True): + model = _get_quanted_weight_model(model_runner, is_drafter=is_drafter) + quant_config = model_runner.vllm_config.quant_config + vllm_dtype = model_runner.vllm_config.model_config.dtype + + reload_state = None + if prepare_model: + reload_state = prepare_quanted_weights_for_loading(model_runner, is_drafter=is_drafter) + weights_quantized = quant_weights(weights, model, quant_config, dtype=vllm_dtype) # Monkey patch the param class to their subclass, as certain models # will check the param type to call the proper weightloader - for name, param in model.named_parameters(): + for _, param in model.named_parameters(): if hasattr(param, "subclass_type"): param.orig_type = param.__class__ param.__class__ = param.subclass_type # Finally load the weights into vllm - loaded_params = model.load_weights(weights_quantized) - # Undo the type change above to the original type - for name, param in model.named_parameters(): - if hasattr(param, "subclass_type"): - param.__class__ = param.orig_type + try: + loaded_params = model.load_weights(weights_quantized) + finally: + # Undo the type change above to the original type + for _, param in model.named_parameters(): + if hasattr(param, "orig_type"): + param.__class__ = param.orig_type + del param.orig_type - if is_mxfp8_npu: - # Re-apply MXFP8 transformations after weight loading - apply_mxfp8_transformation_after_loading(model) + if process_model: + process_quanted_weights_after_loading(model_runner, reload_state, is_drafter=is_drafter) return loaded_params @@ -316,29 +822,10 @@ def process_weights_after_loading_for_vllm10(self, layer) -> None: new torch.nn.Parameter objects, because that removes the weight_loader attribute which we need for refit. """ logger.debug("Applying patch process_weights_after_loading") - try: - from vllm.model_executor.parameter import ( - BlockQuantScaleParameter, - ModelWeightParameter, - ) - except Exception: - print("error") - from torch.nn import Parameter - - def _create_param_from_subclass_attributes(custom_param): - param = Parameter(custom_param.data, requires_grad=False) - base_param_dir = dir(torch.nn.Parameter) - custom_param_dir = dir(custom_param) - # Find the attributes that are unique to the custom parameter - custom_attributes = [ - attr for attr in custom_param_dir if attr not in base_param_dir and not attr.startswith("__") - ] - # Set the custom attributes into the base parameter object - for attr in custom_attributes: - setattr(param, attr, getattr(custom_param, attr)) - - param.subclass_type = type(custom_param) - return param + from vllm.model_executor.parameter import ( + BlockQuantScaleParameter, + ModelWeightParameter, + ) assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized assert self.quant_config.activation_scheme == "dynamic" @@ -370,7 +857,6 @@ def process_weights_after_loading_for_vllm11(self, layer) -> None: Compared to the original process_weights_after_loading in vllm, we just avoid creation of new torch.nn.Parameter objects, because that removes the weight_loader attribute which we need for refit. """ - from torch.nn import Parameter from vllm.model_executor.layers.quantization.utils.fp8_utils import ( maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, @@ -383,21 +869,6 @@ def process_weights_after_loading_for_vllm11(self, layer) -> None: assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized assert self.quant_config.activation_scheme == "dynamic" - def _create_param_from_subclass_attributes(custom_param): - param = Parameter(custom_param.data, requires_grad=False) - base_param_dir = dir(torch.nn.Parameter) - custom_param_dir = dir(custom_param) - # Find the attributes that are unique to the custom parameter - custom_attributes = [ - attr for attr in custom_param_dir if attr not in base_param_dir and not attr.startswith("__") - ] - # Set the custom attributes into the base parameter object - for attr in custom_attributes: - setattr(param, attr, getattr(custom_param, attr)) - - param.subclass_type = type(custom_param) - return param - weight_scale = layer.weight_scale_inv if hasattr(layer, "weight_scale_inv") else layer.weight_scale weight, weight_scale = process_fp8_weight_block_strategy(layer.weight, weight_scale) @@ -432,9 +903,7 @@ def process_weights_after_loading_for_vllm14(self, layer) -> None: Compared to the original process_weights_after_loading in vllm, we just avoid creation of new torch.nn.Parameter objects, because that removes the weight_loader attribute which we need for refit. """ - from torch.nn import Parameter from vllm.model_executor.layers.quantization.utils.fp8_utils import ( - maybe_post_process_fp8_weight_block, process_fp8_weight_block_strategy, ) from vllm.model_executor.parameter import ( @@ -445,21 +914,6 @@ def process_weights_after_loading_for_vllm14(self, layer) -> None: assert self.block_quant and self.quant_config.is_checkpoint_fp8_serialized assert self.quant_config.activation_scheme == "dynamic" - def _create_param_from_subclass_attributes(custom_param): - param = Parameter(custom_param.data, requires_grad=False) - base_param_dir = dir(torch.nn.Parameter) - custom_param_dir = dir(custom_param) - # Find the attributes that are unique to the custom parameter - custom_attributes = [ - attr for attr in custom_param_dir if attr not in base_param_dir and not attr.startswith("__") - ] - # Set the custom attributes into the base parameter object - for attr in custom_attributes: - setattr(param, attr, getattr(custom_param, attr)) - - param.subclass_type = type(custom_param) - return param - weight, weight_scale_inv = process_fp8_weight_block_strategy(layer.weight, layer.weight_scale_inv) layer.weight = _create_param_from_subclass_attributes( @@ -485,7 +939,16 @@ def _create_param_from_subclass_attributes(custom_param): if not hasattr(layer, "input_scale"): layer.input_scale = None - maybe_post_process_fp8_weight_block(layer) + try: + from vllm.model_executor.layers.quantization.utils.fp8_utils import maybe_post_process_fp8_weight_block + except ImportError: + maybe_post_process_fp8_weight_block = None + + if maybe_post_process_fp8_weight_block is not None: + maybe_post_process_fp8_weight_block(layer) + elif hasattr(self, "fp8_linear"): + self.fp8_linear.process_weights_after_loading(layer) + _ensure_linear_params_reloadable(layer) def process_weights_after_loading_moe_for_vllm10(self, layer) -> None: @@ -511,28 +974,10 @@ def process_weights_after_loading_moe_for_vllm10(self, layer) -> None: w2_weight = layer.w2_weight w2_weight_scale_inv = layer.w2_weight_scale_inv - from torch.nn import Parameter - - def _create_param_from_subclass_attributes(custom_data, custom_weight): - param = Parameter(custom_data, requires_grad=False) - base_param_dir = dir(torch.nn.Parameter) - custom_weight_dir = dir(custom_weight) - # Find the attributes that are unique to the custom parameter - custom_attributes = [ - attr for attr in custom_weight_dir if attr not in base_param_dir and not attr.startswith("__") - ] - # Set the custom attributes into the base parameter object - for attr in custom_attributes: - setattr(param, attr, getattr(custom_weight, attr)) - - return param - - layer.w13_weight = _create_param_from_subclass_attributes(w13_weight, layer.w13_weight) - layer.w13_weight_scale_inv = _create_param_from_subclass_attributes( - w13_weight_scale_inv, layer.w13_weight_scale_inv - ) - layer.w2_weight = _create_param_from_subclass_attributes(w2_weight, layer.w2_weight) - layer.w2_weight_scale_inv = _create_param_from_subclass_attributes(w2_weight_scale_inv, layer.w2_weight_scale_inv) + layer.w13_weight = _create_param_from_data_with_attrs(w13_weight, layer.w13_weight) + layer.w13_weight_scale_inv = _create_param_from_data_with_attrs(w13_weight_scale_inv, layer.w13_weight_scale_inv) + layer.w2_weight = _create_param_from_data_with_attrs(w2_weight, layer.w2_weight) + layer.w2_weight_scale_inv = _create_param_from_data_with_attrs(w2_weight_scale_inv, layer.w2_weight_scale_inv) # DeepGemm scales need to be transposed and aligned. We try to do # it ahead of time for performance reasons. @@ -648,28 +1093,12 @@ def process_weights_after_loading_moe_for_vllm14(self, layer) -> None: w13_input_scale=w13_input_scale, w2_input_scale=w2_input_scale, ) - from torch.nn import Parameter - - def _create_param_from_subclass_attributes(custom_data, custom_weight): - param = Parameter(custom_data, requires_grad=False) - base_param_dir = dir(torch.nn.Parameter) - custom_weight_dir = dir(custom_weight) - # Find the attributes that are unique to the custom parameter - custom_attributes = [ - attr for attr in custom_weight_dir if attr not in base_param_dir and not attr.startswith("__") - ] - # Set the custom attributes into the base parameter object - for attr in custom_attributes: - setattr(param, attr, getattr(custom_weight, attr)) - - return param - # Replace parameters with updated versions. Note that this helper # function ensures the replacement is compatible with RL weight reloads. - layer.w13_weight = _create_param_from_subclass_attributes(w13, layer.w13_weight) - layer.w2_weight = _create_param_from_subclass_attributes(w2, layer.w2_weight) - layer.w13_weight_scale_inv = _create_param_from_subclass_attributes(w13_scale, layer.w13_weight_scale_inv) - layer.w2_weight_scale_inv = _create_param_from_subclass_attributes(w2_scale, layer.w2_weight_scale_inv) + layer.w13_weight = _create_param_from_data_with_attrs(w13, layer.w13_weight) + layer.w2_weight = _create_param_from_data_with_attrs(w2, layer.w2_weight) + layer.w13_weight_scale_inv = _create_param_from_data_with_attrs(w13_scale, layer.w13_weight_scale_inv) + layer.w2_weight_scale_inv = _create_param_from_data_with_attrs(w2_scale, layer.w2_weight_scale_inv) self.moe_quant_config = self.get_fused_moe_quant_config(layer) if self.moe_quant_config: @@ -699,12 +1128,13 @@ def _create_param_from_subclass_attributes(custom_data, custom_weight): def apply_vllm_fp8_patches(): - logger.info("Applying vllm fp8 patches for blockwise quantization") - vllm_ver = version.parse(vllm.__version__) if fp8_state.vllm_patches: logger.debug("vLLM FP8 patches already applied") return + logger.info("Applying vllm fp8 patches for blockwise quantization") + vllm_ver = version.parse(vllm.__version__) + func1_path = "vllm.model_executor.layers.quantization.fp8.Fp8LinearMethod.process_weights_after_loading" func2_path = "vllm.model_executor.layers.quantization.fp8.Fp8MoEMethod.process_weights_after_loading" diff --git a/verl/workers/config/model.py b/verl/workers/config/model.py index 95814663dca..fad4ad32827 100644 --- a/verl/workers/config/model.py +++ b/verl/workers/config/model.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import os from dataclasses import dataclass, field from typing import Any, Optional @@ -26,6 +27,30 @@ __all__ = ["HFModelConfig", "MtpConfig"] +def _resolve_custom_chat_template(template: Optional[str]) -> Optional[str]: + if not isinstance(template, str): + return template + + if template.startswith("@"): + path = os.path.expanduser(template[1:]) + with open(path, encoding="utf-8") as f: + return f.read() + + if template.startswith("env:"): + env_name = template.removeprefix("env:") + if env_name not in os.environ: + raise ValueError(f"Custom chat template environment variable is not set: {env_name}") + return os.environ[env_name] + + if template.startswith("${oc.env:") and template.endswith("}"): + env_name = template[len("${oc.env:") : -1] + if env_name not in os.environ: + raise ValueError(f"Custom chat template environment variable is not set: {env_name}") + return os.environ[env_name] + + return template + + @dataclass class MtpConfig(BaseConfig): """ @@ -82,6 +107,7 @@ class HFModelConfig(BaseConfig): "architectures", "local_hf_config_path", "local_tokenizer_path", + "custom_chat_template", "mtp", } @@ -170,6 +196,7 @@ def __post_init__(self): ): self.processor.chat_template = self.tokenizer.chat_template + self.custom_chat_template = _resolve_custom_chat_template(self.custom_chat_template) if self.custom_chat_template is not None: if self.processor is not None: self.processor.chat_template = self.custom_chat_template diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py index d29d5a591f7..de34ef0a8a4 100644 --- a/verl/workers/engine/megatron/transformer_impl.py +++ b/verl/workers/engine/megatron/transformer_impl.py @@ -214,6 +214,10 @@ def _build_tf_config(self): } for key, value in override_transformer_config.items(): provider_overrides[key] = value + if not self.model_config.mtp.enable: + provider_overrides["mtp_num_layers"] = 0 + if provider.csa_compress_ratios is not None: + provider_overrides["csa_compress_ratios"] = provider.csa_compress_ratios[: provider.num_layers] if self.enable_routing_replay: provider_overrides["enable_routing_replay"] = True diff --git a/verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py b/verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py index 82256b0beaa..01a1a1fdc34 100644 --- a/verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py +++ b/verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py @@ -41,6 +41,11 @@ class TensorMetadata(TypedDict): handle: tuple +def _align_offset(offset: int, alignment: int) -> int: + alignment = max(alignment, 1) + return ((offset + alignment - 1) // alignment) * alignment + + # copy from https://github.com/vllm-project/vllm/blob/main/examples/offline_inference/rlhf_utils.py def rebuild_ipc(handle: tuple[Callable, tuple], device_id: int | None = None) -> torch.Tensor: func, args = handle @@ -125,15 +130,20 @@ async def async_send_weights(self, weights): # transfer volume. # weight = weight.to(dtype, non_blocking=True) + weight_nbytes = weight.nbytes + alignment = max(weight.element_size(), 1) + aligned_offset = _align_offset(offset, alignment) + # fill the tensor bucket - if offset + weight.nbytes > self.bucket_size and len(bucket_meta) > 0: + if aligned_offset + weight_nbytes > self.bucket_size and len(bucket_meta) > 0: get_torch_device().synchronize() self.socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": False}) self.socket.recv() bucket_meta = {} offset = 0 + aligned_offset = 0 - if offset + weight.nbytes > self.bucket_size: + if aligned_offset + weight_nbytes > self.bucket_size: assert not self.use_shm, ( f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." f"Please increase rollout.update_weights_bucket_megabytes({self.bucket_size_mb} MB)." @@ -141,6 +151,7 @@ async def async_send_weights(self, weights): self._direct_send_large_weight(name, weight) continue + offset = aligned_offset bucket_meta[name] = { "name": name, "shape": weight.shape, @@ -148,8 +159,8 @@ async def async_send_weights(self, weights): "offset": offset, "handle": None, } - self.buffer[offset : offset + weight.nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) - offset += weight.nbytes + self.buffer[offset : offset + weight_nbytes].copy_(weight.view(-1).view(torch.uint8), non_blocking=True) + offset += weight_nbytes # send the last bucket get_torch_device().synchronize() diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index 95a9391ffa6..53fbf588cff 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -122,11 +122,13 @@ def __new__(cls, **kwargs): # 1. patch for Lora VLLMHijack.hijack() - # 2. patch online fp8 quant - if os.environ.get("VERL_VLLM_FP8_QUANT_ENABLED", "0") == "1": + vllm_config = kwargs.get("vllm_config") + # 2. patch online fp8 quant. Some models, including DeepSeek-V4, get + # fp8 from the HF config rather than an explicit rollout quantization arg. + if os.environ.get("VERL_VLLM_FP8_QUANT_ENABLED", "0") == "1" or is_fp8_model(vllm_config): apply_vllm_fp8_patches() + os.environ["VERL_VLLM_FP8_QUANT_ENABLED"] = "1" # 3. patch QAT (compressed-tensors NVFP4) for dynamic weight loading - vllm_config = kwargs.get("vllm_config") quant_config = getattr(vllm_config, "quant_config", None) if vllm_config else None _is_qat_model = getattr(quant_config, "quant_format", None) == "nvfp4-pack-quantized" _is_modelopt_qat = type(quant_config).__name__ == "ModelOptNvFp4Config" @@ -229,6 +231,12 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False patch_vllm_moe_model_weight_loader(model) assert self.device is not None + quant_reload_state = None + if is_fp8_model(self.model_runner.vllm_config) and not (peft_config and base_sync_done): + from verl.utils.vllm.vllm_fp8_utils import prepare_quanted_weights_for_loading + + quant_reload_state = prepare_quanted_weights_for_loading(self.model_runner) + receiver = BucketedWeightReceiver( zmq_handle=self._get_zmq_handle(), device=self.device, @@ -236,7 +244,10 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False ) receiver.receive_weights( on_bucket_received=lambda weights: self._update_weights( - weights, peft_config=peft_config, base_sync_done=base_sync_done + weights, + peft_config=peft_config, + base_sync_done=base_sync_done, + quant_prepared=quant_reload_state is not None, ) ) @@ -252,6 +263,11 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False modelopt_process_weights_after_loading(self.model_runner.model) logger.info("ModelOpt QAT: process_weights_after_loading completed") + elif quant_reload_state is not None: + from verl.utils.vllm.vllm_fp8_utils import process_quanted_weights_after_loading + + process_quanted_weights_after_loading(self.model_runner, quant_reload_state) + logger.info("FP8/MXFP4: process_weights_after_loading completed") elif use_standard_weight_load: # Some post-load transforms are non-idempotent; run once after all buckets. from vllm.model_executor.model_loader.utils import process_weights_after_loading @@ -259,7 +275,13 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False for model, model_config in self._iter_all_models_with_config(): process_weights_after_loading(model, model_config, self.device) - def _update_weights(self, weights: list[tuple[str, torch.Tensor]], peft_config: dict, base_sync_done: bool): + def _update_weights( + self, + weights: list[tuple[str, torch.Tensor]], + peft_config: dict, + base_sync_done: bool, + quant_prepared: bool = False, + ): if peft_config and base_sync_done: weights = dict(weights) lora_request = TensorLoRARequest( @@ -277,7 +299,12 @@ def _update_weights(self, weights: list[tuple[str, torch.Tensor]], peft_config: if is_fp8_model(self.model_runner.vllm_config): logger.info(f"FP8 model detected (async): {self.model_runner.vllm_config.quant_config}") # Convert bf16 weights to fp8 format before loading - loaded_params = load_quanted_weights(weights, self.model_runner) + loaded_params = load_quanted_weights( + weights, + self.model_runner, + prepare_model=not quant_prepared, + process_model=not quant_prepared, + ) logger.info(f"FP8 weights loaded (async), loaded_params: {len(loaded_params)}") # Keep the draft model in sync when present. if self._use_mtp_drafter_weight_sync(): diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index 112dfb2cb87..fe2410e9a28 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -833,6 +833,22 @@ def _apply_quantization(self) -> tuple[Optional[str], dict]: """Process quantization config. Returns (quantization_str, hf_overrides).""" quantization = self.config.quantization hf_overrides = {} + hf_config = self.model_config.hf_config + if not self.model_config.mtp.enable: + for field in ("num_nextn_predict_layers", "mtp_num_hidden_layers"): + if hasattr(hf_config, field): + hf_overrides[field] = 0 + rope_scaling = getattr(hf_config, "rope_scaling", None) + if isinstance(rope_scaling, dict): + rope_type = rope_scaling.get("rope_type") or rope_scaling.get("type") + if rope_type == "yarn": + rope_scaling_override = dict(rope_scaling) + rope_scaling_override["rope_type"] = rope_type + rope_scaling_override["type"] = rope_type + for key in ("factor", "beta_fast", "beta_slow"): + if key in rope_scaling_override: + rope_scaling_override[key] = float(rope_scaling_override[key]) + hf_overrides["rope_scaling"] = rope_scaling_override if is_torch_npu_available(check_device=False): from verl.utils.vllm.npu_vllm_patch import check_vllm_ascend_before_server_launch From 7cbf7d31ecdc1f2693de0d5f35086c547daba364 Mon Sep 17 00:00:00 2001 From: Hollow Man Date: Tue, 26 May 2026 01:07:46 -0700 Subject: [PATCH 2/2] Address review issues Signed-off-by: Hollow Man --- tests/utils/test_vllm_fp8_utils.py | 14 +++++++++ verl/utils/vllm/vllm_fp8_utils.py | 17 ++++++++--- .../engine/megatron/transformer_impl.py | 5 ++-- .../vllm_rollout/bucketed_weight_transfer.py | 30 ++++++++++--------- verl/workers/rollout/vllm_rollout/utils.py | 19 +++++++----- .../rollout/vllm_rollout/vllm_async_server.py | 4 ++- 6 files changed, 61 insertions(+), 28 deletions(-) diff --git a/tests/utils/test_vllm_fp8_utils.py b/tests/utils/test_vllm_fp8_utils.py index 7ae619c4af1..eb8ecf8d946 100644 --- a/tests/utils/test_vllm_fp8_utils.py +++ b/tests/utils/test_vllm_fp8_utils.py @@ -92,6 +92,20 @@ class Model: ) +def test_model_type_handles_missing_model_and_config(): + vllm_fp8_utils = _import_vllm_fp8_utils() + + class ModelWithoutConfig: + pass + + class ModelWithEmptyConfig: + config = None + + assert vllm_fp8_utils._model_type(None) is None + assert vllm_fp8_utils._model_type(ModelWithoutConfig()) is None + assert vllm_fp8_utils._model_type(ModelWithEmptyConfig()) is None + + def test_forced_scale_name_keeps_mxfp4_suffix_for_other_models(): vllm_fp8_utils = _import_vllm_fp8_utils() diff --git a/verl/utils/vllm/vllm_fp8_utils.py b/verl/utils/vllm/vllm_fp8_utils.py index a0a673418a6..81ee215097e 100644 --- a/verl/utils/vllm/vllm_fp8_utils.py +++ b/verl/utils/vllm/vllm_fp8_utils.py @@ -58,7 +58,10 @@ def _copy_param_attributes(dst_param, src_param): base_param_dir = dir(torch.nn.Parameter) for attr in dir(src_param): if attr not in base_param_dir and not attr.startswith("__"): - setattr(dst_param, attr, getattr(src_param, attr)) + try: + setattr(dst_param, attr, getattr(src_param, attr)) + except Exception: + pass def _create_param_from_subclass_attributes(custom_param, source_param=None): @@ -435,7 +438,7 @@ def _candidate_module_paths(param_name: str) -> list[list[str]]: return candidates - packed_modules_mapping = getattr(model, "packed_modules_mapping", {}) + packed_modules_mapping = getattr(model, "packed_modules_mapping", None) or {} reversed_mapping = { original_name: fused_name for fused_name, original_names_list in packed_modules_mapping.items() @@ -505,6 +508,9 @@ def _is_prequantized_fp8_tensor(tensor): def _model_type(model): + if model is None: + return None + for obj in (model, getattr(model, "config", None), getattr(model, "hf_config", None)): if obj is None: continue @@ -512,8 +518,11 @@ def _model_type(model): if model_type is not None: return model_type config = getattr(model, "config", None) - text_config = getattr(config, "text_config", None) - return getattr(text_config, "model_type", None) + if config is not None: + text_config = getattr(config, "text_config", None) + if text_config is not None: + return getattr(text_config, "model_type", None) + return None def _uses_dot_scale_suffix(model): diff --git a/verl/workers/engine/megatron/transformer_impl.py b/verl/workers/engine/megatron/transformer_impl.py index de34ef0a8a4..2078ad22a94 100644 --- a/verl/workers/engine/megatron/transformer_impl.py +++ b/verl/workers/engine/megatron/transformer_impl.py @@ -216,8 +216,9 @@ def _build_tf_config(self): provider_overrides[key] = value if not self.model_config.mtp.enable: provider_overrides["mtp_num_layers"] = 0 - if provider.csa_compress_ratios is not None: - provider_overrides["csa_compress_ratios"] = provider.csa_compress_ratios[: provider.num_layers] + csa_compress_ratios = getattr(provider, "csa_compress_ratios", None) + if csa_compress_ratios is not None: + provider_overrides["csa_compress_ratios"] = csa_compress_ratios[: provider.num_layers] if self.enable_routing_replay: provider_overrides["enable_routing_replay"] = True diff --git a/verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py b/verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py index 01a1a1fdc34..9e404445d2f 100644 --- a/verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py +++ b/verl/workers/rollout/vllm_rollout/bucketed_weight_transfer.py @@ -25,6 +25,7 @@ import torch import zmq +import zmq.asyncio from torch.multiprocessing.reductions import reduce_tensor from verl.utils.device import get_device_id, get_device_name, get_torch_device @@ -100,7 +101,7 @@ def __init__( self.bucket_size = int(bucket_size_mb) << 20 self.use_shm = use_shm - self.zmq_context = zmq.Context.instance() + self.zmq_context = zmq.asyncio.Context.instance() self.socket = None self.buffer = None self.shm = None @@ -116,7 +117,7 @@ async def async_send_weights(self, weights): try: self._init_socket() - self._init_buffer() + await self._init_buffer() # send bucket weights offset = 0 @@ -137,8 +138,7 @@ async def async_send_weights(self, weights): # fill the tensor bucket if aligned_offset + weight_nbytes > self.bucket_size and len(bucket_meta) > 0: get_torch_device().synchronize() - self.socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": False}) - self.socket.recv() + await self._send_bucket(bucket_meta, is_last=False) bucket_meta = {} offset = 0 aligned_offset = 0 @@ -148,7 +148,7 @@ async def async_send_weights(self, weights): f"Weight {name}({weight.shape}, {weight.dtype}) is too large to fit in the bucket." f"Please increase rollout.update_weights_bucket_megabytes({self.bucket_size_mb} MB)." ) - self._direct_send_large_weight(name, weight) + await self._direct_send_large_weight(name, weight) continue offset = aligned_offset @@ -164,8 +164,7 @@ async def async_send_weights(self, weights): # send the last bucket get_torch_device().synchronize() - self.socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": True}) - self.socket.recv() + await self._send_bucket(bucket_meta, is_last=True) finally: self._cleanup() @@ -180,13 +179,13 @@ def _init_socket(self): self.socket = self.zmq_context.socket(zmq.REQ) self.socket.bind(self.zmq_handle) - def _init_buffer(self): + async def _init_buffer(self): """build communication buffer""" buffer, shm = None, None if not self.use_shm: buffer = torch.empty(self.bucket_size, dtype=torch.uint8, device=f"{get_device_name()}:{get_device_id()}") handle = reduce_tensor(buffer) - self.socket.send_pyobj(handle) + await self.socket.send_pyobj(handle) else: import uuid @@ -196,9 +195,9 @@ def _init_buffer(self): buffer = torch.frombuffer(shm.buf, dtype=torch.uint8) comm_metadata = {"name": shm_name, "size": self.bucket_size} - self.socket.send_pyobj(comm_metadata) + await self.socket.send_pyobj(comm_metadata) - self.socket.recv() + await self.socket.recv() self.buffer = buffer self.shm = shm @@ -224,7 +223,11 @@ def _cleanup(self): get_torch_device().ipc_collect() get_torch_device().empty_cache() - def _direct_send_large_weight(self, name: str, weight: torch.Tensor): + async def _send_bucket(self, bucket_meta: dict[str, TensorMetadata], is_last: bool): + await self.socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": is_last}) + await self.socket.recv() + + async def _direct_send_large_weight(self, name: str, weight: torch.Tensor): """Send a weight larger than the bucket size via cuda ipc or share memory.""" logger.debug(f"Direct sending large weight {name}({weight.shape}, {weight.dtype})") # TODO: support fallback to shared memory @@ -237,8 +240,7 @@ def _direct_send_large_weight(self, name: str, weight: torch.Tensor): "offset": 0, "handle": handle, } - self.socket.send_pyobj({"bucket_meta": bucket_meta, "is_last": False}) - self.socket.recv() + await self._send_bucket(bucket_meta, is_last=False) class BucketedWeightReceiver: diff --git a/verl/workers/rollout/vllm_rollout/utils.py b/verl/workers/rollout/vllm_rollout/utils.py index 53fbf588cff..7cb70de4f73 100644 --- a/verl/workers/rollout/vllm_rollout/utils.py +++ b/verl/workers/rollout/vllm_rollout/utils.py @@ -232,10 +232,13 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False assert self.device is not None quant_reload_state = None + drafter_quant_reload_state = None if is_fp8_model(self.model_runner.vllm_config) and not (peft_config and base_sync_done): from verl.utils.vllm.vllm_fp8_utils import prepare_quanted_weights_for_loading quant_reload_state = prepare_quanted_weights_for_loading(self.model_runner) + if self._use_mtp_drafter_weight_sync(): + drafter_quant_reload_state = prepare_quanted_weights_for_loading(self.model_runner, is_drafter=True) receiver = BucketedWeightReceiver( zmq_handle=self._get_zmq_handle(), @@ -267,6 +270,12 @@ def update_weights_from_ipc(self, peft_config: dict = None, base_sync_done=False from verl.utils.vllm.vllm_fp8_utils import process_quanted_weights_after_loading process_quanted_weights_after_loading(self.model_runner, quant_reload_state) + if drafter_quant_reload_state is not None: + process_quanted_weights_after_loading( + self.model_runner, + drafter_quant_reload_state, + is_drafter=True, + ) logger.info("FP8/MXFP4: process_weights_after_loading completed") elif use_standard_weight_load: # Some post-load transforms are non-idempotent; run once after all buckets. @@ -299,16 +308,12 @@ def _update_weights( if is_fp8_model(self.model_runner.vllm_config): logger.info(f"FP8 model detected (async): {self.model_runner.vllm_config.quant_config}") # Convert bf16 weights to fp8 format before loading - loaded_params = load_quanted_weights( - weights, - self.model_runner, - prepare_model=not quant_prepared, - process_model=not quant_prepared, - ) + reload_kwargs = {"prepare_model": not quant_prepared, "process_model": not quant_prepared} + loaded_params = load_quanted_weights(weights, self.model_runner, **reload_kwargs) logger.info(f"FP8 weights loaded (async), loaded_params: {len(loaded_params)}") # Keep the draft model in sync when present. if self._use_mtp_drafter_weight_sync(): - load_quanted_weights(weights, self.model_runner, is_drafter=True) + load_quanted_weights(weights, self.model_runner, is_drafter=True, **reload_kwargs) else: logger.info("Loading standard weights (non-FP8, async)") for model in self._iter_all_models(): diff --git a/verl/workers/rollout/vllm_rollout/vllm_async_server.py b/verl/workers/rollout/vllm_rollout/vllm_async_server.py index fe2410e9a28..b579e65d78f 100644 --- a/verl/workers/rollout/vllm_rollout/vllm_async_server.py +++ b/verl/workers/rollout/vllm_rollout/vllm_async_server.py @@ -839,6 +839,8 @@ def _apply_quantization(self) -> tuple[Optional[str], dict]: if hasattr(hf_config, field): hf_overrides[field] = 0 rope_scaling = getattr(hf_config, "rope_scaling", None) + if rope_scaling is not None and hasattr(rope_scaling, "to_dict"): + rope_scaling = rope_scaling.to_dict() if isinstance(rope_scaling, dict): rope_type = rope_scaling.get("rope_type") or rope_scaling.get("type") if rope_type == "yarn": @@ -846,7 +848,7 @@ def _apply_quantization(self) -> tuple[Optional[str], dict]: rope_scaling_override["rope_type"] = rope_type rope_scaling_override["type"] = rope_type for key in ("factor", "beta_fast", "beta_slow"): - if key in rope_scaling_override: + if rope_scaling_override.get(key) is not None: rope_scaling_override[key] = float(rope_scaling_override[key]) hf_overrides["rope_scaling"] = rope_scaling_override