diff --git a/tests/models/test_transformers_ulysses.py b/tests/models/test_transformers_ulysses.py index c3e6eef9021..a3c7f579b0f 100644 --- a/tests/models/test_transformers_ulysses.py +++ b/tests/models/test_transformers_ulysses.py @@ -29,12 +29,12 @@ from verl.utils.distributed import initialize_global_process_group from verl.utils.model import compute_position_id_with_mask, create_random_mask from verl.utils.ulysses import ( - FSDPUlyssesShardingManager, gather_outputs_and_unpad, get_ulysses_sequence_parallel_world_size, set_ulysses_sequence_parallel_group, ulysses_pad_and_slice_inputs, ) +from verl.workers.sharding_manager.fsdp_ulysses import FSDPUlyssesShardingManager if get_device_name() == "cuda": from flash_attn.bert_padding import index_first_axis, rearrange, unpad_input @@ -50,6 +50,7 @@ class SequenceParallelConfig: config: PretrainedConfig sp_size: int is_valid: bool + attn_implementation: str = "flash_attention_2" def test_configs(): @@ -86,6 +87,20 @@ def test_configs(): ) ) + try: + from transformers import GlmMoeDsaConfig + + configs.append( + SequenceParallelConfig( + GlmMoeDsaConfig(num_hidden_layers=2, n_routed_experts=4, vocab_size=1024, index_topk=128), + sp_size=4, + is_valid=True, + attn_implementation="sdpa", + ) + ) + except ImportError: + pass + return configs @@ -103,13 +118,18 @@ def test_hf_casual_fwd_bwd(test_config): context = contextlib.nullcontext() if test_config.is_valid else pytest.raises(AssertionError) with context: world_size = torch.distributed.get_world_size() - _hf_casual_fwd_bwd(test_config.config, test_config.sp_size, world_size // test_config.sp_size) + _hf_casual_fwd_bwd( + test_config.config, + test_config.sp_size, + world_size // test_config.sp_size, + attn_implementation=test_config.attn_implementation, + ) # TODO: seems not work, will cause `socketStartConnect: Connect to xxx failed : Software caused connection abort` # torch.distributed.destroy_process_group() -def _hf_casual_fwd(config, sp_size, dp_size): +def 
_hf_casual_fwd(config, sp_size, dp_size, attn_implementation="flash_attention_2"): assert get_torch_device().device_count() >= 2, "need at least 2 gpus for test" ulysses_device_mesh = init_device_mesh( @@ -124,7 +144,7 @@ def _hf_casual_fwd(config, sp_size, dp_size): # patch before load with torch.device(get_device_name()): model = AutoModelForCausalLM.from_config( - config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + config=config, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation ) apply_monkey_patch(model, sp_size) model = model.to(device=get_device_name()) @@ -188,7 +208,7 @@ def _hf_casual_fwd(config, sp_size, dp_size): torch.testing.assert_close(mean_local, mean_full, rtol=1e-2, atol=1e-5) -def _hf_casual_fwd_bwd(config, sp_size, dp_size): +def _hf_casual_fwd_bwd(config, sp_size, dp_size, attn_implementation="flash_attention_2"): assert get_torch_device().device_count() >= 2, "need at least 2 gpus for test" ulysses_device_mesh = init_device_mesh( @@ -203,7 +223,7 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size): # patch before load with torch.device(get_device_name()): model = AutoModelForCausalLM.from_config( - config=config, torch_dtype=torch.bfloat16, attn_implementation="flash_attention_2" + config=config, torch_dtype=torch.bfloat16, attn_implementation=attn_implementation ) apply_monkey_patch(model, sp_size) model = model.to(device=get_device_name()) @@ -271,13 +291,17 @@ def _hf_casual_fwd_bwd(config, sp_size, dp_size): mean_full.backward() mean_local.backward() - # 3. check the gradients - grad = model.model.layers[0].self_attn.q_proj.weight.grad - grad_full = model_no_sp.model.layers[0].self_attn.q_proj.weight.grad + # 3. 
def _build_causal_mask_from_position_ids(position_ids: torch.Tensor, seq_len: int, dtype: torch.dtype):
    """Create an additive causal attention mask for packed sequences.

    A new packed sequence is assumed to begin wherever a position id does not
    increase relative to its predecessor, e.g. ``[0, 1, 2, 0, 1, 2, 3, 0, 1]``
    holds three sequences.  A query at index ``i`` may attend to key index
    ``j`` only when ``j <= i`` and both indices belong to the same sequence.

    Args:
        position_ids: 2D tensor indexed as ``[batch, seq]``.
        seq_len: Length of the sequence axis of the mask to build.
        dtype: Floating dtype of the returned mask.

    Returns:
        Tensor of shape ``[batch, 1, seq_len, seq_len]`` holding ``0`` at
        allowed positions and ``torch.finfo(dtype).min`` everywhere else.
    """
    dev = position_ids.device
    n_rows = position_ids.shape[0]

    # Flag every index that opens a new packed sequence; index 0 is never flagged.
    seq_starts = torch.zeros(n_rows, seq_len, device=dev, dtype=torch.long)
    if seq_len > 1:
        seq_starts[:, 1:] = (position_ids[:, 1:] <= position_ids[:, :-1]).long()

    # The running count of start flags gives each token the id of its sequence.
    seq_id = seq_starts.cumsum(dim=1)  # [B, S]
    in_same_seq = seq_id.unsqueeze(2) == seq_id.unsqueeze(1)  # [B, S, S]

    # not_future[i, j] is True when key j is not ahead of query i.
    token_idx = torch.arange(seq_len, device=dev)
    not_future = token_idx.unsqueeze(0) <= token_idx.unsqueeze(1)  # [S, S]

    allowed = in_same_seq & not_future.unsqueeze(0)  # [B, S, S]

    blocked_value = torch.finfo(dtype).min
    additive = torch.where(allowed, torch.zeros(1, device=dev, dtype=dtype), blocked_value)
    return additive.unsqueeze(1)  # [B, 1, S, S]
+ """ + from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS + + batch_size, seq_length = hidden_states.shape[:-1] + cos, sin = position_embeddings + ulysses_sp_size = get_ulysses_sequence_parallel_world_size() + + # Always rebuild a float additive mask from position_ids. + # The model's create_causal_mask may return a bool mask (incompatible with the DSA indexer + # which expects float 0/-inf) or a local-sized mask (wrong when ulysses_sp_size > 1). + if position_ids is not None and position_ids.dim() == 1: + position_ids = position_ids.unsqueeze(0) + + if ulysses_sp_size > 1 and position_ids is not None: + from verl.utils.ulysses import get_ulysses_sequence_parallel_group + + sp_group = get_ulysses_sequence_parallel_group() + position_ids_full = _all_gather_seq(position_ids, sp_group, ulysses_sp_size) + full_seq_length = seq_length * ulysses_sp_size + attention_mask = _build_causal_mask_from_position_ids(position_ids_full, full_seq_length, hidden_states.dtype) + elif position_ids is not None: + attention_mask = _build_causal_mask_from_position_ids(position_ids, seq_length, hidden_states.dtype) + + # ===== Query path (MLA) ===== + if self.q_lora_rank is None: + query_states = self.q_proj(hidden_states) + q_resid = None + else: + q_resid = self.q_a_layernorm(self.q_a_proj(hidden_states)) + query_states = self.q_b_proj(q_resid) + query_states = query_states.view(batch_size, seq_length, -1, self.qk_head_dim).transpose(1, 2) + q_nope, q_pe = torch.split(query_states, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1) + q_pe = _apply_rotary_pos_emb(q_pe, cos, sin, unsqueeze_dim=1) + + # ===== KV path (MLA compressed) ===== + compressed_kv = self.kv_a_proj_with_mqa(hidden_states) + k_compressed, k_pe = torch.split(compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) + k_compressed = self.kv_a_layernorm(k_compressed) + + kv_expanded = self.kv_b_proj(k_compressed) + kv_expanded = kv_expanded.view(batch_size, seq_length, -1, 
self.qk_nope_head_dim + self.v_head_dim) + k_nope, value_states = torch.split(kv_expanded, [self.qk_nope_head_dim, self.v_head_dim], dim=-1) + k_nope = k_nope.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + k_pe = k_pe.view(batch_size, 1, seq_length, self.qk_rope_head_dim) + k_pe = _apply_rotary_pos_emb(k_pe, cos, sin, unsqueeze_dim=1) + k_pe = k_pe.expand(-1, k_nope.shape[1], -1, -1) + + # Assemble full Q and K + query_states = torch.cat([q_nope, q_pe], dim=-1) + key_states = torch.cat([k_nope, k_pe], dim=-1) + + # Cache update + if past_key_values is not None: + key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx) + + # ===== Ulysses SP: all-to-all to gather full seq, scatter heads ===== + if ulysses_sp_size > 1: + # [B, H, S_local, D] -> [B, H/sp, S_full, D] + query_states = gather_seq_scatter_heads(query_states, seq_dim=2, head_dim=1) + key_states = gather_seq_scatter_heads(key_states, seq_dim=2, head_dim=1) + value_states = gather_seq_scatter_heads(value_states, seq_dim=2, head_dim=1) + full_seq_length = query_states.shape[2] + else: + full_seq_length = seq_length + + # ===== Indexer (DSA sparse mask) ===== + if not self.skip_topk or prev_topk_indices is None: + if ulysses_sp_size > 1: + # Indexer needs full-sequence hidden_states and q_resid + from verl.utils.ulysses import get_ulysses_sequence_parallel_group + + sp_group = get_ulysses_sequence_parallel_group() + # All-gather hidden_states along seq dim for indexer + hidden_states_full = _all_gather_seq(hidden_states, sp_group, ulysses_sp_size) + q_resid_full = _all_gather_seq(q_resid, sp_group, ulysses_sp_size) if q_resid is not None else None + # Position embeddings also need full seq + cos_full = _all_gather_seq(cos, sp_group, ulysses_sp_size) + sin_full = _all_gather_seq(sin, sp_group, ulysses_sp_size) + position_embeddings_full = (cos_full, sin_full) + else: + hidden_states_full = hidden_states + q_resid_full = q_resid + position_embeddings_full = 
position_embeddings + + indexer_mask = ( + attention_mask[:, 0, :, :] + if attention_mask is not None and attention_mask.dim() == 4 + else attention_mask.unsqueeze(1) + if attention_mask is not None + else None + ) + topk_indices = self.indexer( + hidden_states_full, + q_resid_full, + position_embeddings_full, + indexer_mask, + use_cache=past_key_values is not None, + ) + else: + topk_indices = prev_topk_indices + + # Build combined DSA + causal mask + total_len = key_states.shape[2] + # topk_indices is [B, S_full, topk] + index_mask = torch.full( + (batch_size, full_seq_length, total_len), + float("-inf"), + device=hidden_states.device, + dtype=query_states.dtype, + ) + index_mask.scatter_(-1, topk_indices, 0.0) + index_mask = index_mask.unsqueeze(1) # [B, 1, S_full, T] + if attention_mask is not None and attention_mask.dim() == 4: + causal_mask = attention_mask[..., :total_len] + combined_mask = index_mask + causal_mask + else: + if attention_mask is not None: + if attention_mask.dim() == 2: + attention_mask = attention_mask.unsqueeze(1).unsqueeze(2) + elif attention_mask.dim() == 3: + attention_mask = attention_mask.unsqueeze(1) + combined_mask = index_mask + attention_mask + else: + combined_mask = index_mask + + # DSA produces a dense 4D float mask incompatible with flash attention. 
def _all_gather_seq(tensor: torch.Tensor, group, sp_size: int) -> torch.Tensor:
    """Concatenate every rank's ``tensor`` along dim 1.

    Returns ``tensor`` itself, untouched, when ``sp_size`` is 1 or less.
    Otherwise performs ``torch.distributed.all_gather`` over ``group`` and
    joins the ``sp_size`` received pieces along dim 1.
    """
    if sp_size > 1:
        # One receive buffer per rank in the group.
        pieces = [torch.empty_like(tensor) for _ in range(sp_size)]
        torch.distributed.all_gather(pieces, tensor.contiguous(), group=group)
        return torch.cat(pieces, dim=1)
    return tensor


def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Split the last dim at its midpoint and return ``cat(-second, first)``."""
    mid = x.shape[-1] // 2
    first, second = x[..., :mid], x[..., mid:]
    return torch.cat((-second, first), dim=-1)


def _apply_rotary_pos_emb(x, cos, sin, unsqueeze_dim=1):
    """Apply rotary embedding: ``x * cos + rotate_half(x) * sin``.

    ``cos`` and ``sin`` get an extra axis inserted at ``unsqueeze_dim`` so that
    they broadcast against ``x``.
    """
    cos_b = cos.unsqueeze(unsqueeze_dim)
    sin_b = sin.unsqueeze(unsqueeze_dim)
    rotated = _rotate_half(x)
    return (x * cos_b) + (rotated * sin_b)


def _is_flash_attention_requested(config):
    """Report whether ``config`` asks for a flash-attention implementation.

    Delegates to ``transformers.utils.generic.is_flash_attention_requested``
    when it can be imported; on ``ImportError`` it falls back to checking
    ``config._attn_implementation`` against the known flash-attention names.
    """
    try:
        from transformers.utils.generic import is_flash_attention_requested as _hf_check

        return _hf_check(config)
    except ImportError:
        requested = getattr(config, "_attn_implementation", None)
        return requested in ("flash_attention_2", "flash_attention_3")


def _eager_attention_forward(module, query, key, value, attention_mask, **kwargs):
    """Forward to the transformers eager attention for GLM MoE DSA.

    The import is done at call time (the original notes this avoids circular
    dependencies).
    """
    from transformers.models.glm_moe_dsa.modeling_glm_moe_dsa import eager_attention_forward as _hf_eager

    return _hf_eager(module, query, key, value, attention_mask, **kwargs)
def _estimate_glm_moe_dsa_flops(config, tokens_sum, batch_seqlens, delta_time):
    """Estimate achieved TFLOPS for a GLM MoE model with MLA attention and a DSA indexer.

    The estimate is the sum of two parts:

    * a dense term, ``6 * N * tokens_sum``, where ``N`` counts the per-layer
      linear parameters (MLA projections, indexer projections, MoE gate and the
      active experts) times the layer count, plus embedding and LM head;
    * an attention term that, per sequence, adds the indexer's full
      ``seq x seq`` scoring cost and the sparse attention cost over at most
      ``index_topk`` keys per query.

    Args:
        config: Model config exposing the hidden/MoE/MLA/indexer attributes read below.
        tokens_sum: Total number of tokens processed.
        batch_seqlens: Iterable of per-sequence lengths.
        delta_time: Elapsed time used to turn FLOPs into a rate.

    Returns:
        Estimated FLOPs divided by ``delta_time`` and by ``1e12``.
    """
    hidden = config.hidden_size
    vocab = config.vocab_size
    expert_inter = config.moe_intermediate_size
    n_layers = config.num_hidden_layers
    n_experts = config.n_routed_experts

    experts_per_tok = config.num_experts_per_tok
    n_shared = config.n_shared_experts

    n_heads = config.num_attention_heads
    q_rank = config.q_lora_rank
    kv_rank = config.kv_lora_rank
    qk_dim = config.qk_head_dim
    rope_dim = config.qk_rope_head_dim
    nope_dim = config.qk_nope_head_dim
    v_dim = config.v_head_dim

    # DSA indexer settings.
    idx_dim = config.index_head_dim
    idx_heads = config.index_n_heads
    idx_topk = config.index_topk

    # Query side: a single projection without q-LoRA, otherwise a down/up pair.
    # The indexer's query projection reads from whichever width feeds q.
    if q_rank is not None:
        q_params = hidden * q_rank + q_rank * (n_heads * qk_dim)
        idx_q_params = q_rank * (idx_heads * idx_dim)
    else:
        q_params = hidden * (n_heads * qk_dim)
        idx_q_params = hidden * (idx_heads * idx_dim)

    # MLA linear parameters per layer: q + kv_a + kv_b + o_proj.
    kv_a_params = hidden * (kv_rank + rope_dim)
    kv_b_params = kv_rank * (n_heads * (nope_dim + v_dim))
    o_params = (n_heads * v_dim) * hidden
    attn_params = q_params + kv_a_params + kv_b_params + o_params

    # Indexer linear parameters per layer.
    # NOTE(review): the wk / weights_proj widths used here (kv_lora_rank and
    # index_n_heads -> num_attention_heads) should be confirmed against the
    # indexer implementation, which is not visible from this function.
    idx_params = idx_q_params + kv_rank * (idx_heads * idx_dim) + idx_heads * n_heads

    # MoE: router gate plus three matrices for each active (routed + shared) expert.
    gate_params = hidden * n_experts
    expert_params = hidden * expert_inter * (experts_per_tok + n_shared) * 3
    moe_params = gate_params + expert_params

    embed_and_head_params = vocab * hidden * 2

    dense_params = (moe_params + attn_params + idx_params) * n_layers + embed_and_head_params
    dense_flops = 6 * dense_params * tokens_sum

    # Per-sequence attention cost, computed in a single pass over batch_seqlens:
    #   indexer: seq * seq * index_head_dim * index_n_heads
    #   sparse:  seq * min(index_topk, seq) * (qk_head_dim + v_head_dim) * heads
    per_seq = [
        (
            seqlen * seqlen * idx_dim * idx_heads,
            seqlen * min(idx_topk, seqlen) * (qk_dim + v_dim) * n_heads,
        )
        for seqlen in batch_seqlens
    ]
    indexer_cost = sum(pair[0] for pair in per_seq)
    sparse_cost = sum(pair[1] for pair in per_seq)

    attn_flops = 6 * (indexer_cost + sparse_cost) * n_layers

    total_flops = dense_flops + attn_flops
    return total_flops * (1.0 / delta_time) / 1e12