From 7a8dd02c5be42a6f570d5236f31b0a1cb64f765b Mon Sep 17 00:00:00 2001 From: mikeq Date: Wed, 27 May 2026 19:04:01 +0800 Subject: [PATCH 1/5] feat(veomni/config): Add new configuration items for the veomni training engine --- verl/trainer/config/engine/veomni.yaml | 31 ++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/verl/trainer/config/engine/veomni.yaml b/verl/trainer/config/engine/veomni.yaml index 492b2f506bc..4b48dc2fc4a 100644 --- a/verl/trainer/config/engine/veomni.yaml +++ b/verl/trainer/config/engine/veomni.yaml @@ -7,6 +7,15 @@ param_offload: False # Whether to offload optimizer state to CPU optimizer_offload: False +# Whether to offload gradients to CPU +grad_offload: false + +# Only for FSDP2: offload param/grad/optimizer during train +offload_policy: false + +# policy for wrapping the model +wrap_policy: {} + # FSDP group size. -1 means use all available GPUs. fsdp_size: -1 @@ -16,6 +25,9 @@ expert_parallel_size: 1 mixed_precision: true +# Mixed precision training param dtype +dtype: bfloat16 + # Random seed for reproducibility. seed: 42 @@ -28,15 +40,31 @@ enable_full_shard: true ckpt_manager: dcp +# Path to load checkpoint from, if any +load_checkpoint_path: null + # Only for FSDP1: FSDP1 configuration, prefetch the next forward-pass all-gather # before the current forward computation. forward_prefetch: true +# Reshard after forward pass to reduce memory footprint +# For FSDP1, `false` enables `ShardingStrategy.SHARD_GRAD_OP` +reshard_after_forward: true + +# Whether to use original parameters in fsdp. Only available in fsdp1 +use_orig_params: false + strategy: veomni # Whether to use torch compile in fsdp. use_torch_compile: false +# Whether to use entropy_from_logits_with_chunking in fsdp. +entropy_from_logits_with_chunking: false + +# Whether to use entropy checkpointing in fsdp. +entropy_checkpointing: false + # Whether to use forward only in fsdp. forward_only: false @@ -66,6 +94,9 @@ force_use_huggingface: false activation_gpu_limit: 0.0 +# List of basic modules to use +basic_modules: [] + # MoE expert-load monitor interval. When > 0, attach VeOmni's MoERouterMonitor. # Scalar metrics flow through Tracking; heatmap images go to wandb on rank 0. moe_load_balance_monitor_interval: 0 From 7c98eb81fce8d5a7c49ce75f6fa044211e4c79ae Mon Sep 17 00:00:00 2001 From: mikeq Date: Wed, 27 May 2026 19:31:40 +0800 Subject: [PATCH 2/5] feat(veomni/cfg): add 3 new implementation config params for Qwen3.5 components in veomni ops_implementation --- verl/trainer/config/engine/veomni.yaml | 3 +++ verl/workers/config/engine.py | 19 +++++++++++++++++++ .../workers/engine/veomni/transformer_impl.py | 3 +++ 3 files changed, 25 insertions(+) diff --git a/verl/trainer/config/engine/veomni.yaml b/verl/trainer/config/engine/veomni.yaml index 4b48dc2fc4a..af858cc9413 100644 --- a/verl/trainer/config/engine/veomni.yaml +++ b/verl/trainer/config/engine/veomni.yaml @@ -88,6 +88,9 @@ cross_entropy_loss_implementation: eager rms_norm_implementation: eager swiglu_mlp_implementation: eager rotary_pos_emb_implementation: eager +rms_norm_gated_implementation: eager +causal_conv1d_implementation: eager +chunk_gated_delta_rule_implementation: eager load_balancing_loss_implementation: eager force_use_huggingface: false diff --git a/verl/workers/config/engine.py b/verl/workers/config/engine.py index b5ad8e9cfea..f51ae9f45ae 100644 --- a/verl/workers/config/engine.py +++ b/verl/workers/config/engine.py @@ -330,6 +330,22 @@ class VeOmniEngineConfig(EngineConfig): in distributed training. Important: this will negatively impact performance, so only use it for debugging. mixed_precision (Optional[dict[str, Any]]): Mixed precision configuration for FSDP, default None + rms_norm_gated_implementation (str): Gated RMSNorm implementation (Qwen3.5 GatedDeltaNet + ``self.norm``). ``"fla"`` uses fla.modules.FusedRMSNormGated (requires flash-linear-attention, + GPU). ``"eager"`` (default) uses the HuggingFace Qwen3_5RMSNormGated. Qwen3.5 has no NPU + backend today — selecting any non-eager value on NPU raises at OpSlot bind time. + causal_conv1d_implementation (str): Varlen depthwise causal conv1d implementation (Qwen3.5 + GatedDeltaNet pre-mixer). ``"fla"`` uses fla.modules.convolution.causal_conv1d (requires + flash-linear-attention, GPU). ``"eager"`` (default) leaves causal_conv1d_fn unset; the varlen + training path then raises because no torch fallback handles cu_seqlens. Qwen3.5 has no NPU + backend today — selecting any non-eager value on NPU raises at OpSlot bind time. + chunk_gated_delta_rule_implementation (str): Chunk gated delta-rule kernel for Qwen3.5 linear + attention. ``"fla"`` uses fla.ops.gated_delta_rule.chunk_gated_delta_rule (requires + flash-linear-attention, GPU). ``"flash_qla"`` uses QwenLM FlashQLA (requires the optional + flash-qla extra, Hopper SM90 only — no Ampere/Ada below or Blackwell above; SM10x wheels are + WIP upstream). ``"eager"`` (default) uses transformers' torch_chunk_gated_delta_rule, which + does NOT support cu_seqlens; varlen training therefore raises at runtime. Qwen3.5 has no NPU + backend today — selecting any non-eager value on NPU raises at OpSlot bind time. """ @@ -367,6 +383,9 @@ class VeOmniEngineConfig(EngineConfig): swiglu_mlp_implementation: str = "eager" rotary_pos_emb_implementation: str = "eager" load_balancing_loss_implementation: str = "eager" + rms_norm_gated_implementation: str = "eager" + causal_conv1d_implementation: str = "eager" + chunk_gated_delta_rule_implementation: str = "eager" force_use_huggingface: bool = False activation_gpu_limit: float = 0.0 basic_modules: Optional[list[str]] = field(default_factory=list) diff --git a/verl/workers/engine/veomni/transformer_impl.py b/verl/workers/engine/veomni/transformer_impl.py index 518565509e8..4f76b2fab15 100644 --- a/verl/workers/engine/veomni/transformer_impl.py +++ b/verl/workers/engine/veomni/transformer_impl.py @@ -277,6 +277,9 @@ def _build_model_optimizer(self): swiglu_mlp_implementation=self.engine_config.swiglu_mlp_implementation, rotary_pos_emb_implementation=self.engine_config.rotary_pos_emb_implementation, load_balancing_loss_implementation=self.engine_config.load_balancing_loss_implementation, + rms_norm_gated_implementation=self.engine_config.rms_norm_gated_implementation, + causal_conv1d_implementation=self.engine_config.causal_conv1d_implementation, + chunk_gated_delta_rule_implementation=self.engine_config.chunk_gated_delta_rule_implementation, ) # Load base model with specified configuration and dtype From 1dee19287ba3c009952e590cb73fe51b72d4137b Mon Sep 17 00:00:00 2001 From: mikeq Date: Wed, 27 May 2026 19:39:15 +0800 Subject: [PATCH 3/5] chore(config): Remove redundant configuration items of the veomni training engine as suggested by Gemini --- verl/trainer/config/engine/veomni.yaml | 16 ---------------- 1 file changed, 16 deletions(-) diff --git a/verl/trainer/config/engine/veomni.yaml b/verl/trainer/config/engine/veomni.yaml index af858cc9413..24d0a60e613 100644 --- a/verl/trainer/config/engine/veomni.yaml +++ b/verl/trainer/config/engine/veomni.yaml @@ -7,15 +7,6 @@ param_offload: False # Whether to offload optimizer state to CPU optimizer_offload: False -# Whether to offload gradients to CPU -grad_offload: false - -# Only for FSDP2: offload param/grad/optimizer during train -offload_policy: false - -# policy for wrapping the model -wrap_policy: {} - # FSDP group size. -1 means use all available GPUs. fsdp_size: -1 @@ -47,13 +38,6 @@ load_checkpoint_path: null # before the current forward computation. forward_prefetch: true -# Reshard after forward pass to reduce memory footprint -# For FSDP1, `false` enables `ShardingStrategy.SHARD_GRAD_OP` -reshard_after_forward: true - -# Whether to use original parameters in fsdp. Only available in fsdp1 -use_orig_params: false - strategy: veomni # Whether to use torch compile in fsdp. From 18f03ab6b66d64df0bc6ec9b7d827ea138cd7df6 Mon Sep 17 00:00:00 2001 From: mikequan0425 Date: Wed, 27 May 2026 21:02:47 +0800 Subject: [PATCH 4/5] refactor(config): Clean up redundant items in the veomni engine configuration based on reviews, as the latest version of veomni no longer supports fsdp1. --- verl/trainer/config/engine/veomni.yaml | 3 --- verl/workers/config/engine.py | 7 ------- 2 files changed, 10 deletions(-) diff --git a/verl/trainer/config/engine/veomni.yaml b/verl/trainer/config/engine/veomni.yaml index 24d0a60e613..da676b99d85 100644 --- a/verl/trainer/config/engine/veomni.yaml +++ b/verl/trainer/config/engine/veomni.yaml @@ -16,9 +16,6 @@ expert_parallel_size: 1 mixed_precision: true -# Mixed precision training param dtype -dtype: bfloat16 - # Random seed for reproducibility. seed: 42 diff --git a/verl/workers/config/engine.py b/verl/workers/config/engine.py index f51ae9f45ae..6c9209b26a6 100644 --- a/verl/workers/config/engine.py +++ b/verl/workers/config/engine.py @@ -272,11 +272,9 @@ class VeOmniEngineConfig(EngineConfig): The inheritance from BaseConfig provides omegaconf.DictConfig-like interface for a dataclass config. Args: - wrap_policy (Dict[str, Any]): Configuration for FSDP wrap policy. param_offload (bool): Whether to offload parameters to CPU, default False optimizer_offload (bool): Whether to offload optimizer states to CPU, default False offload_policy (bool): Whether to offload policy model parameters, default False - reshard_after_forward (bool): Whether to reshard parameters after forward pass, default True fsdp_size (int): FSDP group size. -1 means use all available GPUs, default -1 ulysses_parallel_size (int): Ulysses sequence parallel size, default 1 expert_parallel_size (int): Expert parallel size, default 1 @@ -324,7 +322,6 @@ class VeOmniEngineConfig(EngineConfig): basic_modules (list[str]): List of basic modules to use, default None forward_prefetch (bool): Whether to prefetch parameters for next forward pass, default False model_dtype (str): Model data type used to initialize the transformers model. default "fp32" - use_orig_params (bool): Whether to use original parameters when initialize FSDP1, default False seed (int): Random seed for reproducibility. full_determinism (bool): If true, enable_full_determinism is called to ensure reproducible results in distributed training. Important: this will negatively impact performance, so only use it for @@ -351,11 +348,7 @@ class VeOmniEngineConfig(EngineConfig): _mutable_fields = EngineConfig._mutable_fields | {"attn_implementation"} - wrap_policy: dict[str, Any] = field(default_factory=dict) - offload_policy: bool = False - reshard_after_forward: bool = True forward_prefetch: bool = False - use_orig_params: bool = False entropy_from_logits_with_chunking: bool = False use_torch_compile: bool = True entropy_checkpointing: bool = False From ae430cfa85fa667b3792c4c471f2db631c217472 Mon Sep 17 00:00:00 2001 From: mikeq Date: Thu, 28 May 2026 11:53:58 +0800 Subject: [PATCH 5/5] fix pre-commit --- .../config/_generated_ppo_veomni_trainer.yaml | 21 +++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml index bacf0617ed4..4577be8c9d7 100644 --- a/verl/trainer/config/_generated_ppo_veomni_trainer.yaml +++ b/verl/trainer/config/_generated_ppo_veomni_trainer.yaml @@ -35,9 +35,12 @@ actor_rollout_ref: init_device: meta enable_full_shard: true ckpt_manager: dcp + load_checkpoint_path: null forward_prefetch: true strategy: veomni use_torch_compile: false + entropy_from_logits_with_chunking: false + entropy_checkpointing: false forward_only: false enable_fsdp_offload: false enable_reentrant: false @@ -47,9 +50,13 @@ actor_rollout_ref: rms_norm_implementation: eager swiglu_mlp_implementation: eager rotary_pos_emb_implementation: eager + rms_norm_gated_implementation: eager + causal_conv1d_implementation: eager + chunk_gated_delta_rule_implementation: eager load_balancing_loss_implementation: eager force_use_huggingface: false activation_gpu_limit: 0.0 + basic_modules: [] moe_load_balance_monitor_interval: 0 router_replay: _target_: verl.workers.config.EngineRouterReplayConfig @@ -205,9 +212,12 @@ actor_rollout_ref: init_device: meta enable_full_shard: true ckpt_manager: dcp + load_checkpoint_path: null forward_prefetch: true strategy: veomni use_torch_compile: false + entropy_from_logits_with_chunking: false + entropy_checkpointing: false forward_only: true enable_fsdp_offload: false enable_reentrant: false @@ -217,9 +227,13 @@ actor_rollout_ref: rms_norm_implementation: ${oc.select:actor_rollout_ref.actor.veomni.rms_norm_implementation,eager} swiglu_mlp_implementation: ${oc.select:actor_rollout_ref.actor.veomni.swiglu_mlp_implementation,eager} rotary_pos_emb_implementation: ${oc.select:actor_rollout_ref.actor.veomni.rotary_pos_emb_implementation,eager} + rms_norm_gated_implementation: eager + causal_conv1d_implementation: eager + chunk_gated_delta_rule_implementation: eager load_balancing_loss_implementation: ${oc.select:actor_rollout_ref.actor.veomni.load_balancing_loss_implementation,eager} force_use_huggingface: false activation_gpu_limit: 0.0 + basic_modules: [] moe_load_balance_monitor_interval: 0 router_replay: _target_: verl.workers.config.EngineRouterReplayConfig @@ -492,9 +506,12 @@ critic: init_device: meta enable_full_shard: true ckpt_manager: dcp + load_checkpoint_path: null forward_prefetch: true strategy: veomni use_torch_compile: false + entropy_from_logits_with_chunking: false + entropy_checkpointing: false forward_only: false enable_fsdp_offload: false enable_reentrant: false @@ -504,9 +521,13 @@ critic: rms_norm_implementation: eager swiglu_mlp_implementation: eager rotary_pos_emb_implementation: eager + rms_norm_gated_implementation: eager + causal_conv1d_implementation: eager + chunk_gated_delta_rule_implementation: eager load_balancing_loss_implementation: eager force_use_huggingface: false activation_gpu_limit: 0.0 + basic_modules: [] moe_load_balance_monitor_interval: 0 router_replay: _target_: verl.workers.config.EngineRouterReplayConfig