From ffd539a1e6ce658c641602b96bae66fc13ecdea8 Mon Sep 17 00:00:00 2001 From: "chucai.dzq" Date: Thu, 28 May 2026 11:51:35 +0800 Subject: [PATCH 1/2] feat(mcore): add GLM-5/DeepSeek-V3 model support (mbridge + megatron-bridge) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Adds the GLM-5.1 / DeepSeek-V3 / GLM-4.7-Flash architecture family, which open-source AReaL did not previously support. Coverage spans both the mbridge path (used by default) and the megatron-bridge path opted into via mcore.bridge_type=megatron-bridge. Stacks on top of the Bailing-MoE megatron-bridge adapter PR, which introduced the shared cross-cutting infrastructure (optional mbridge import wrapping, Bridge type-annotation cleanups, migration doc). New model code: - areal/models/mcore/deepseek_v3.py: HF config -> MLATransformerConfig conversion, homogeneous MLA layer specs, _has_dsa() helper. Handles DeepseekV3ForCausalLM, GlmMoeDsaForCausalLM (GLM-5.1), and Glm4MoeForCausalLM (GLM-4.7-Flash) — all three share the underlying MLA + MoE topology, GLM-5.1 additionally exposes a DSA indexer. - areal/models/mcore/deepseek_v3_bridge.py: mbridge LLMBridge subclass registered as deepseek_v3 / glm_moe_dsa / glm4_moe_lite. - areal/models/mcore/dsa_mla_attention.py: custom DSAMLASelfAttention module that inherits Attention directly (not DSAttention) so packed THD inputs work without modification. Implements the DSA indexer (wq_b, wk, k_norm, weights_proj) called by the layer spec. - areal/models/mcore/glm5_megatron_bridge.py: megatron-bridge MegatronModelBridge subclass for GlmMoeDsaForCausalLM, with DSA indexer weight mappings and MTP layer support. - areal/experimental/ops/dsa/: six tilelang kernel files implementing the DSA indexer forward/backward and the sparse MLA forward/backward used by DSAMLASelfAttention. Specialized for DSV3/GLM-5.1 latent geometry (kv_lora_rank=512, qk_rope_head_dim=64). Registry / engine wiring: - areal/models/mcore/registry.py: add _DEEPSEEK_V3_ARCHITECTURES set and _supplement_dsa_config() helper; register DSV3/GLM-5/GLM-4.7 architectures in make_hf_and_mcore_config (no-bridge fallback) and make_mcore_layer_specs; inject AReaL's DSA-aware layer spec into provider.transformer_layer_spec for the megatron-bridge path when _has_dsa(hf_config) is True. - areal/engine/megatron_engine.py: import deepseek_v3_bridge (lazy via try/except, since it depends on mbridge) and glm5_megatron_bridge (unconditional, since it uses megatron-bridge) so their decorators fire on engine load. Docs: - docs/en/best_practices/migrate_to_megatron_bridge.md: extend the supported-architectures table with DSV3/GLM-5/GLM-4.7 entries and describe the DSA-aware layer spec injection alongside the Bailing one. --- areal/engine/megatron_engine.py | 5 + areal/experimental/ops/__init__.py | 1 + areal/experimental/ops/dsa/__init__.py | 1 + areal/experimental/ops/dsa/indexer.py | 102 +++ areal/experimental/ops/dsa/sparse_mla.py | 163 ++++ .../ops/dsa/tilelang_indexer_bwd.py | 172 +++++ .../ops/dsa/tilelang_indexer_fwd.py | 154 ++++ .../ops/dsa/tilelang_sparse_mla_bwd.py | 408 ++++++++++ .../ops/dsa/tilelang_sparse_mla_fwd.py | 231 ++++++ areal/models/mcore/deepseek_v3.py | 354 +++++++++ areal/models/mcore/deepseek_v3_bridge.py | 378 ++++++++++ areal/models/mcore/dsa_mla_attention.py | 699 ++++++++++++++++++ areal/models/mcore/glm5_megatron_bridge.py | 355 +++++++++ areal/models/mcore/registry.py | 48 +- .../migrate_to_megatron_bridge.md | 51 +- 15 files changed, 3102 insertions(+), 20 deletions(-) create mode 100644 areal/experimental/ops/__init__.py create mode 100644 areal/experimental/ops/dsa/__init__.py create mode 100644 areal/experimental/ops/dsa/indexer.py create mode 100644 areal/experimental/ops/dsa/sparse_mla.py create mode 100644 areal/experimental/ops/dsa/tilelang_indexer_bwd.py create mode 100644 areal/experimental/ops/dsa/tilelang_indexer_fwd.py create mode 100644 areal/experimental/ops/dsa/tilelang_sparse_mla_bwd.py create mode 100644 areal/experimental/ops/dsa/tilelang_sparse_mla_fwd.py create mode 100644 areal/models/mcore/deepseek_v3.py create mode 100644 areal/models/mcore/deepseek_v3_bridge.py create mode 100644 areal/models/mcore/dsa_mla_attention.py create mode 100644 areal/models/mcore/glm5_megatron_bridge.py diff --git a/areal/engine/megatron_engine.py b/areal/engine/megatron_engine.py index 6f7ccc7f3d..6f42684386 100644 --- a/areal/engine/megatron_engine.py +++ b/areal/engine/megatron_engine.py @@ -43,8 +43,13 @@ import areal.models.mcore.bailing_moe_bridge # noqa: F401 # register bridge except ImportError: pass +try: + import areal.models.mcore.deepseek_v3_bridge # noqa: F401 # register bridge +except ImportError: + pass # megatron-bridge adapters do not depend on mbridge and register on import. import areal.models.mcore.bailing_moe_megatron_bridge # noqa: F401 # register bridge +import areal.models.mcore.glm5_megatron_bridge # noqa: F401 # register bridge from areal.api import ( FinetuneSpec, InferenceEngine, diff --git a/areal/experimental/ops/__init__.py b/areal/experimental/ops/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/areal/experimental/ops/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/areal/experimental/ops/dsa/__init__.py b/areal/experimental/ops/dsa/__init__.py new file mode 100644 index 0000000000..9881313609 --- /dev/null +++ b/areal/experimental/ops/dsa/__init__.py @@ -0,0 +1 @@ +# SPDX-License-Identifier: Apache-2.0 diff --git a/areal/experimental/ops/dsa/indexer.py b/areal/experimental/ops/dsa/indexer.py new file mode 100644 index 0000000000..d0067972e3 --- /dev/null +++ b/areal/experimental/ops/dsa/indexer.py @@ -0,0 +1,102 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os + +import torch + +from .tilelang_indexer_bwd import indexer_bwd_interface +from .tilelang_indexer_fwd import indexer_fwd_interface + +# DSA indexer topk_indices recording for routing comparison. +# When AREAL_DUMP_ROUTING is set, each layer's forward appends its +# topk_indices here. Cleared after dump in megatron_engine.py. +_recorded_dsa_indices: list[torch.Tensor] = [] +_record_dsa = bool(os.environ.get("AREAL_DUMP_ROUTING", "")) + + +def get_recorded_dsa_indices() -> list[torch.Tensor]: + return _recorded_dsa_indices + + +def clear_recorded_dsa_indices(): + _recorded_dsa_indices.clear() + + +def pytorch_extract_topk_scores(logits, topk_indices, dim=-1): + valid_mask = topk_indices != -1 + safe_indices = topk_indices.clamp(min=0).to(torch.int64) + scores = torch.gather(logits, dim=dim, index=safe_indices) + scores = torch.where(valid_mask, scores, float("-inf")) + return scores + + +class IndexerFunction(torch.autograd.Function): + @staticmethod + def forward( + ctx, + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk: int, + topk_indices: torch.Tensor | None = None, + ): + _, head_num, _ = index_q.shape + logits = indexer_fwd_interface( + index_q, index_k, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True + ) + if topk_indices is None: + sorted_indices = torch.argsort(-logits, dim=-1, stable=True) + topk_indices = sorted_indices[..., :topk].to(torch.int32) + index_score = torch.gather( + logits, dim=-1, index=topk_indices.to(torch.int64) + ) + topk_indices = topk_indices.masked_fill(index_score == -torch.inf, -1) + + index_score = pytorch_extract_topk_scores(logits, topk_indices) + + if _record_dsa: + _recorded_dsa_indices.append(topk_indices.detach().cpu()) + + ctx.save_for_backward( + index_q, index_k, weights, cu_seqlen_ks, cu_seqlen_ke, topk_indices + ) + ctx.topk = topk + ctx.head_num = head_num + return index_score, topk_indices + + @staticmethod + def backward(ctx, grad_scores, grad_indices): + index_q, index_k, weights, cu_seqlen_ks, cu_seqlen_ke, topk_indices = ( + ctx.saved_tensors + ) + grad_q, grad_w, grad_k = indexer_bwd_interface( + index_q, weights, index_k, topk_indices, grad_scores + ) + return grad_q, grad_k, grad_w, None, None, None, None, None, None, None + + +def lighting_indexer( + index_q: torch.Tensor, + index_k: torch.Tensor, + weights: torch.Tensor, + cu_seqlen_ks: torch.Tensor, + cu_seqlen_ke: torch.Tensor, + topk: int, + topk_indices: torch.Tensor | None = None, +): + weights = weights.squeeze(-1) + return IndexerFunction.apply( + index_q, index_k, weights, cu_seqlen_ks, cu_seqlen_ke, topk, topk_indices + ) + + +def generate_varlen_mask_params(cu_seqlens): + seq_len = cu_seqlens[-1].item() + q_indices = torch.arange(0, seq_len, device=cu_seqlens.device) + seq_indices = torch.searchsorted(cu_seqlens, q_indices, right=True) - 1 + starts = cu_seqlens[seq_indices] + ends = q_indices + 1 + assert torch.all((ends - starts) > 0) + return starts, ends diff --git a/areal/experimental/ops/dsa/sparse_mla.py b/areal/experimental/ops/dsa/sparse_mla.py new file mode 100644 index 0000000000..56f63ff62d --- /dev/null +++ b/areal/experimental/ops/dsa/sparse_mla.py @@ -0,0 +1,163 @@ +# SPDX-License-Identifier: Apache-2.0 + +import os + +import torch + +from .tilelang_sparse_mla_bwd import sparse_mla_bwd +from .tilelang_sparse_mla_fwd import sparse_mla_fwd_interface + + +def _pytorch_sparse_mla_bwd(q, kv, tl_out, grad_output, indices, tl_lse, scaling): + """G11s: chunked pure-pytorch SparseMLA backward. + + Memory: materializing k_gathered at shape (S, G, TOPK, D_full) can be + ~22 GB per rank. Chunk along S dimension (64 rows at a time ≈ 300 MB). + + Math correctness: kernel's tl_lse is log2(sum_i exp(score_i * scaling)). + softmax prob = exp(score_i * scaling - tl_lse * ln(2)). + """ + S, H, D_full = q.shape + S_kv, G, _ = kv.shape + D_v = tl_out.shape[-1] + TOPK = indices.shape[-1] + H_per_group = H // G + ln2 = 0.6931471805599453 + + q_f = q.float() + kv_f = kv.float() + do_f = grad_output.float() + o_f = tl_out.float() + + # Precompute Delta = sum_d o * do (small, fp32) + delta_full = (o_f * do_f).sum(dim=-1) # (S, H) + + # Outputs + dq = torch.zeros(S, H, D_full, device=q.device, dtype=torch.float32) + dkv_fp32 = torch.zeros(S_kv, G, D_full, device=q.device, dtype=torch.float32) + + CHUNK = 32 + for s0 in range(0, S, CHUNK): + s1 = min(s0 + CHUNK, S) + cs = s1 - s0 # chunk size + idx_c = indices[s0:s1] # (cs, G, TOPK) + safe_c = idx_c.clamp(min=0).long() # (cs, G, TOPK) + valid_c = idx_c != -1 # (cs, G, TOPK) + + # Gather K per group: (cs, G, TOPK, D_full) + k_c = torch.stack( + [ + kv_f[:, g, :] + .index_select(0, safe_c[:, g, :].reshape(-1)) + .view(cs, TOPK, D_full) + for g in range(G) + ], + dim=1, + ) + # Compute per-head via per-group q split (avoid 8x replication) + q_c = q_f[s0:s1].view(cs, G, H_per_group, D_full) # (cs, G, Hg, D_full) + do_c = do_f[s0:s1].view(cs, G, H_per_group, D_v) # (cs, G, Hg, D_v) + v_c = k_c[..., :D_v] # (cs, G, TOPK, D_v) + lse_c = tl_lse[s0:s1].view(cs, G, H_per_group) # (cs, G, Hg) + delta_c = delta_full[s0:s1].view(cs, G, H_per_group) # (cs, G, Hg) + + # score[cs, g, hg, t] = q_c @ k_c + score = torch.einsum("cghd,cgtd->cght", q_c, k_c) * scaling + # valid mask broadcast + valid_chg = valid_c.unsqueeze(2) # (cs, G, 1, TOPK) + score = torch.where( + valid_chg, + score, + torch.tensor(-1e30, device=score.device, dtype=score.dtype), + ) + # prob = exp(score - lse * ln(2)) + prob = torch.exp(score - lse_c.unsqueeze(-1) * ln2) + prob = torch.where(valid_chg, prob, torch.zeros_like(prob)) + del score + + # dp_raw = do @ V^T + dp_raw = torch.einsum("cghd,cgtd->cght", do_c, v_c) + dp = prob * (dp_raw - delta_c.unsqueeze(-1)) * scaling + dp = torch.where(valid_chg, dp, torch.zeros_like(dp)) + del dp_raw + + # dq[cs, g, hg, d] = sum_t dp * K + dq_c = torch.einsum("cght,cgtd->cghd", dp, k_c).view(cs, H, D_full) + dq[s0:s1] = dq_c + del dq_c + + # dkv scatter per group + for g in range(G): + idx_g = safe_c[:, g, :].reshape(-1) # (cs*TOPK,) + # K contribution: (cs, Hg, TOPK) dp @ q + dp_g = dp[:, g, :, :] # (cs, Hg, TOPK) + q_g = q_c[:, g, :, :] # (cs, Hg, D_full) + prob_g = prob[:, g, :, :] # (cs, Hg, TOPK) + do_g = do_c[:, g, :, :] # (cs, Hg, D_v) + valid_gm = valid_c[:, g, :].unsqueeze(-1) # (cs, TOPK, 1) + k_contrib = torch.einsum("cht,chd->ctd", dp_g, q_g) # (cs, TOPK, D_full) + v_contrib = torch.einsum("cht,chd->ctd", prob_g, do_g) # (cs, TOPK, D_v) + k_contrib = torch.where(valid_gm, k_contrib, torch.zeros_like(k_contrib)) + v_contrib = torch.where(valid_gm, v_contrib, torch.zeros_like(v_contrib)) + dkv_fp32[:, g, :].index_add_(0, idx_g, k_contrib.reshape(-1, D_full)) + dkv_fp32[:, g, :D_v].index_add_(0, idx_g, v_contrib.reshape(-1, D_v)) + del k_contrib, v_contrib + del prob, dp, k_c, v_c + + return dq.to(q.dtype).contiguous(), dkv_fp32.contiguous() + + +class SparseMLA(torch.autograd.Function): + @staticmethod + def forward(ctx, q, kv, indices, scaling): + """ + Args: + q: Query tensor (seq_len, heads, dim_plus_tail_dim) + kv: Key-Value tensor (seq_len_kv, kv_group, dim_plus_tail_dim) + indices: Sparse indices tensor (seq_len, kv_group, topk) + + Returns: + out: Output tensor (seq_len, heads, dim) + """ + indices = indices.contiguous() + q, kv = q.contiguous(), kv.contiguous() + ctx.scaling = scaling + tl_out, tl_lse = sparse_mla_fwd_interface(q, kv, indices, sm_scale=scaling) + + # Save tensors for backward pass + ctx.save_for_backward(q, kv, indices, tl_out, tl_lse) + + return tl_out, tl_lse + + @staticmethod + def backward(ctx, grad_output, grad_lse): + """ + Args: + grad_output: Gradient of the loss with respect to output + + Returns: + Gradients for q, kv, and indices (None for indices) + """ + q, kv, indices, tl_out, tl_lse = ctx.saved_tensors + scaling = ctx.scaling + + # G11r: pure-pytorch fallback when AREAL_SPARSE_MLA_PYTORCH_BWD=1. + # Used to bypass the L20X TileLang bwd kernel NaN bug while we figure + # out the kernel cache invalidation problem. + if os.environ.get("AREAL_SPARSE_MLA_PYTORCH_BWD", "0") == "1": + tl_dq, tl_dkv = _pytorch_sparse_mla_bwd( + q, kv, tl_out, grad_output, indices, tl_lse, scaling + ) + else: + tl_dq, tl_dkv = sparse_mla_bwd( + q, + kv, + tl_out, + grad_output.contiguous(), + indices, + tl_lse, + sm_scale=scaling, + ) + + # Return gradients for each input (None for indices as it's not differentiable) + return tl_dq, tl_dkv, None, None diff --git a/areal/experimental/ops/dsa/tilelang_indexer_bwd.py b/areal/experimental/ops/dsa/tilelang_indexer_bwd.py new file mode 100644 index 0000000000..c6dee14b74 --- /dev/null +++ b/areal/experimental/ops/dsa/tilelang_indexer_bwd.py @@ -0,0 +1,172 @@ +# SPDX-License-Identifier: Apache-2.0 + +# ruff: noqa +# Adapted from https://github.com/tile-ai/tilelang/blob/4956b5835fa554af6c03d4a6289cad44bf310869/examples/dsa_sparse_finetune/indexer_bwd.py +import tilelang as tl +import tilelang.language as T +import torch + +BF16 = T.bfloat16 +FP32 = T.float32 +INT32 = T.int32 + +pass_configs = { + tl.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tl.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, +} + + +@tl.jit(pass_configs=pass_configs) +def tl_indexer_bwd_impl( + heads: int, + dim: int, + topk: int, + block_I: int = 32, + num_stages: int = 0, + num_threads: int = 128, +): + assert num_stages == 0 + assert topk == tl.math.next_power_of_2(topk) + assert topk % block_I == 0 + assert heads <= 64 and heads % 8 == 0 + seq_len = T.symbolic("seq_len") + q_seq_len = T.symbolic("q_seq_len") + + dtype: str = BF16 + accum_dtype: str = FP32 + index_q_shape = [q_seq_len, heads, dim] + weights_shape = [q_seq_len, heads] + index_k_shape = [seq_len, dim] + shape_p = [q_seq_len, topk] + topk_indices_shape = [q_seq_len, topk] + + pad_heads = heads + if heads < 16: + pad_heads = 16 + + @T.prim_func + def tl_indexer_bwd_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), + IndexK: T.Tensor(index_k_shape, dtype), + Weights: T.Tensor(weights_shape, FP32), + TopkIndices: T.Tensor(topk_indices_shape, INT32), + OGrad: T.Tensor(shape_p, FP32), + dIndexQ: T.Tensor(index_q_shape, dtype), + dWeights: T.Tensor(weights_shape, FP32), + dIndexK: T.Tensor(index_k_shape, FP32), + ): + with T.Kernel(q_seq_len, threads=num_threads) as (bx): + index_q_shared = T.alloc_shared([pad_heads, dim], dtype=FP32) + weights_shared = T.alloc_shared([pad_heads], dtype=FP32) + index_k_shared = T.alloc_shared([block_I, dim], dtype=FP32) + indices_shared = T.alloc_shared([block_I], dtype=INT32) + d_index_q_frag = T.alloc_fragment([pad_heads, dim], dtype=accum_dtype) + d_weights_frag = T.alloc_fragment([pad_heads], dtype=accum_dtype) + d_index_k_frag = T.alloc_fragment([block_I, dim], dtype=accum_dtype) + logits = T.alloc_fragment((block_I, pad_heads), dtype=accum_dtype) + _logits = T.alloc_shared((block_I, pad_heads), dtype=accum_dtype) + grad = T.alloc_shared([block_I], dtype=FP32) + + num_blocks = T.ceildiv(topk, block_I) + for i, j in T.Parallel(pad_heads, dim): + index_q_shared[i, j] = T.if_then_else(i < heads, IndexQ[bx, i, j], 0) + for i in T.Parallel(heads): + weights_shared[i] = Weights[bx, i] + + T.fill(d_index_q_frag, 0) + T.fill(d_weights_frag, 0) + + # for bi_i in T.Pipelined(num_blocks, num_stages=num_stages): + for bi_i in T.serial(num_blocks): + for i in T.Parallel(block_I): + if bi_i * block_I + i < topk: + indices_shared[i] = TopkIndices[bx, bi_i * block_I + i] + grad[i] = OGrad[bx, bi_i * block_I + i] + + T.sync_threads() + for i, j in T.Parallel(block_I, dim): + index_k_shared[i, j] = T.if_then_else( + indices_shared[i] > -1 and indices_shared[i] < seq_len, + IndexK[indices_shared[i], j], + 0, + ) + + T.sync_threads() + T.gemm( + index_k_shared, + index_q_shared, + logits, + transpose_A=False, + transpose_B=True, + clear_accum=True, + ) + for i, j in T.Parallel(block_I, heads): + logits[i, j] = T.max(logits[i, j], 0) + + d_weights_i = T.alloc_fragment((block_I, pad_heads), accum_dtype) + for i, j in T.Parallel(block_I, heads): + d_weights_i[i, j] = grad[i] * logits[i, j] + T.reduce_sum(d_weights_i, d_weights_frag, dim=0, clear=False) + + for i, j in T.Parallel(block_I, pad_heads): + _logits[i, j] = T.if_then_else( + logits[i, j] > 0 and j < heads, grad[i] * weights_shared[j], 0 + ) + T.sync_threads() + T.gemm( + _logits, + index_k_shared, + d_index_q_frag, + transpose_A=True, + transpose_B=False, + clear_accum=False, + ) + + T.gemm( + _logits, + index_q_shared, + d_index_k_frag, + transpose_A=False, + transpose_B=False, + clear_accum=True, + ) + + for i, j in T.Parallel(block_I, dim): + if indices_shared[i] > -1 and indices_shared[i] < seq_len: + T.atomic_add( + dIndexK[indices_shared[i], j], d_index_k_frag[i, j] + ) + + T.copy(d_index_q_frag[:heads, :], dIndexQ[bx, :, :]) + T.copy(d_weights_frag[:heads], dWeights[bx, :]) + + return tl_indexer_bwd_kernel + + +def indexer_bwd_interface( + index_q: torch.Tensor, + weights: torch.Tensor, + index_k: torch.Tensor, + topk_indices: torch.Tensor, + grad_scores: torch.Tensor, +): + _, head_num, head_dim = index_q.shape + k_top = topk_indices.shape[1] + + grad_scores = grad_scores.contiguous() + grad_q = torch.empty_like(index_q) + grad_w = torch.empty_like(weights, dtype=torch.float32) + grad_k = torch.zeros_like(index_k, dtype=torch.float32) + + tl_indexer_bwd_impl(head_num, head_dim, k_top)( + index_q.contiguous(), + index_k.contiguous(), + weights.squeeze(-1).contiguous(), + topk_indices.contiguous(), + grad_scores, + grad_q, + grad_w.squeeze(-1), + grad_k, + ) + + return grad_q, grad_w, grad_k diff --git a/areal/experimental/ops/dsa/tilelang_indexer_fwd.py b/areal/experimental/ops/dsa/tilelang_indexer_fwd.py new file mode 100644 index 0000000000..0861d97873 --- /dev/null +++ b/areal/experimental/ops/dsa/tilelang_indexer_fwd.py @@ -0,0 +1,154 @@ +# SPDX-License-Identifier: Apache-2.0 + +# ruff: noqa +# Adapted from https://github.com/tile-ai/tilelang/blob/4956b5835fa554af6c03d4a6289cad44bf310869/examples/deepseek_v32/fp8_lighting_indexer.py +import tilelang +import torch +from tilelang import language as T + + +@tilelang.jit( + pass_configs={ + tilelang.PassConfigKey.TL_ENABLE_FAST_MATH: True, + }, +) +def tl_indexer_fwd_impl( + heads, + index_dim, + block_N=256, + num_stages=3, + threads=512, + block_Q=None, +): + if block_Q is None: + block_Q = 128 // heads + dtype = T.bfloat16 + accum_dtype = T.float32 + index_dtype = T.int32 + + seq_len = T.dynamic("seq_len") + seq_len_kv = T.dynamic("seq_len_kv") + + index_q_shape = [seq_len * heads, index_dim] + index_k_shape = [seq_len_kv, index_dim] + logits_shape = [seq_len, seq_len_kv] + + @T.prim_func + def tl_indexer_fwd_kernel( + IndexQ: T.Tensor(index_q_shape, dtype), # type: ignore + IndexK: T.Tensor(index_k_shape, dtype), # type: ignore + Logits: T.Tensor(logits_shape, accum_dtype), # type: ignore + Weights: T.Tensor([seq_len, heads], accum_dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], index_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], index_dtype), # type: ignore + ): + with T.Kernel(T.ceildiv(seq_len, block_Q), threads=threads) as bx: + index_q_shared = T.alloc_shared([block_Q * heads, index_dim], dtype) + index_k_shared = T.alloc_shared([block_N, index_dim], dtype) + s = T.alloc_fragment([block_N, block_Q * heads], accum_dtype) + s_reshaped = T.reshape(s, (block_N, block_Q, heads)) + logits = T.alloc_fragment([block_N, block_Q], accum_dtype) + weights = T.alloc_fragment([block_Q, heads], accum_dtype) + + seq_len_i = bx * block_Q + + cu_k_s_min = T.alloc_var(index_dtype) + cu_k_e_max = T.alloc_var(index_dtype) + + cu_k_s_min = 2147483647 + cu_k_e_max = -2147483648 + + for bq_i in T.serial(block_Q): + cu_k_s_min = T.min( + cu_k_s_min, T.min(CuSeqLenKS[seq_len_i + bq_i], seq_len_kv) + ) + for bq_i in T.serial(block_Q): + cu_k_e_max = T.max( + cu_k_e_max, T.min(CuSeqLenKE[seq_len_i + bq_i], seq_len_kv) + ) + + T.copy(IndexQ[seq_len_i * heads, 0], index_q_shared) + T.copy(Weights[seq_len_i, 0], weights) + + for nbn_i in T.Pipelined( + T.ceildiv(cu_k_e_max - cu_k_s_min, block_N), num_stages=num_stages + ): + T.copy(IndexK[cu_k_s_min + nbn_i * block_N, 0], index_k_shared) + + T.gemm( + index_k_shared, + index_q_shared, + s, + transpose_B=True, + clear_accum=True, + policy=T.GemmWarpPolicy.FullCol, + ) + + for bn_i, bq_i, h_i in T.Parallel(block_N, block_Q, heads): + s_reshaped[bn_i, bq_i, h_i] = ( + T.max(s_reshaped[bn_i, bq_i, h_i], 0) * weights[bq_i, h_i] + ) + + T.reduce_sum(s_reshaped, logits, dim=-1, clear=True) + + for bq_i, bn_i in T.Parallel(block_Q, block_N): + Logits[seq_len_i + bq_i, cu_k_s_min + nbn_i * block_N + bn_i] = ( + logits[bn_i, bq_i] + ) + + return tl_indexer_fwd_kernel + + +@tilelang.jit +def clean_logits_( + threads: int = 512, + block_K: int = 4096, +): + seq_len = T.dynamic("seq_len") + seq_len_kv = T.dynamic("seq_len_kv") + + dtype = T.float + indices_dtype = T.int32 + + @T.prim_func + def clean_logits_kernel( + Logits: T.Tensor([seq_len, seq_len_kv], dtype), # type: ignore + CuSeqLenKS: T.Tensor([seq_len], indices_dtype), # type: ignore + CuSeqLenKE: T.Tensor([seq_len], indices_dtype), # type: ignore + ): + with T.Kernel(seq_len, threads=threads) as bx: + tx = T.thread_binding(0, threads, thread="threadIdx.x") + cu_k_s = CuSeqLenKS[bx] + cu_k_e = CuSeqLenKE[bx] + + for n_i in T.Pipelined(T.ceildiv(seq_len_kv, block_K)): + for k_i in T.serial(block_K // threads): + idx = n_i * block_K + k_i * threads + tx + if idx < cu_k_s or idx >= cu_k_e: + Logits[bx, idx] = -T.infinity(dtype) + + return clean_logits_kernel + + +def indexer_fwd_interface( + q, kv, weights, cu_seqlen_ks, cu_seqlen_ke, clean_logits=True +): + seq_len, heads, index_dim = q.shape + seq_len_kv = kv.shape[0] + + clean_logits_kernel = clean_logits_() + + tl_indexer_fwd_kernel = tl_indexer_fwd_impl(heads=heads, index_dim=index_dim) + + logits = torch.empty([seq_len, seq_len_kv], device=q.device, dtype=torch.float32) + tl_indexer_fwd_kernel( + q.view(seq_len * heads, index_dim), + kv, + logits, + weights, + cu_seqlen_ks, + cu_seqlen_ke, + ) + if clean_logits: + clean_logits_kernel(logits, cu_seqlen_ks, cu_seqlen_ke) + return logits diff --git a/areal/experimental/ops/dsa/tilelang_sparse_mla_bwd.py b/areal/experimental/ops/dsa/tilelang_sparse_mla_bwd.py new file mode 100644 index 0000000000..f07d7808a4 --- /dev/null +++ b/areal/experimental/ops/dsa/tilelang_sparse_mla_bwd.py @@ -0,0 +1,408 @@ +# SPDX-License-Identifier: Apache-2.0 + +# ruff: noqa +# Adapt from https://github.com/tile-ai/tilelang/blob/4ff81c7d40803d269569e157e847623e84553f78/examples/deepseek_v32/sparse_mla_bwd.py +import tilelang +import torch +from tilelang import language as T + + +@tilelang.jit(out_idx=[-1]) +def preprocess( + B, + S, + H, + D, + block_ND=32, + num_stages=5, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + shape = [B, S, H, D] + + @T.prim_func + def preprocess_kernel( + O: T.Tensor(shape, dtype), + dO: T.Tensor(shape, dtype), + Delta: T.Tensor([B, S, H], accum_dtype), + ): + with T.Kernel(H, T.ceildiv(S, block_ND), B) as (bx, by, bz): + o = T.alloc_fragment([block_ND, block_ND], accum_dtype) + do = T.alloc_fragment([block_ND, block_ND], accum_dtype) + delta = T.alloc_fragment([block_ND], accum_dtype) + acc = T.alloc_fragment([block_ND, block_ND], accum_dtype) + T.clear(acc) + for k in T.Pipelined(T.ceildiv(D, block_ND), num_stages=num_stages): + T.copy( + O[ + bz, + by * block_ND : (by + 1) * block_ND, + bx, + k * block_ND : (k + 1) * block_ND, + ], + o, + ) + T.copy( + dO[ + bz, + by * block_ND : (by + 1) * block_ND, + bx, + k * block_ND : (k + 1) * block_ND, + ], + do, + ) + for i, j in T.Parallel(block_ND, block_ND): + acc[i, j] += o[i, j] * do[i, j] + T.reduce_sum(acc, delta, 1) + T.copy(delta, Delta[bz, by * block_ND : (by + 1) * block_ND, bx]) + + return preprocess_kernel + + +@tilelang.jit(out_idx=[-1]) +def postprocess( + B, + S_kv, + D, + D_tail, + kv_group=1, + block_N=64, + threads=128, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + dkv_shape = [B, S_kv, kv_group, D + D_tail] + + @T.prim_func + def postprocess_kernel( + dKV: T.Tensor(dkv_shape, accum_dtype), + dKV_out: T.Tensor(dkv_shape, dtype), + ): + with T.Kernel(T.ceildiv(S_kv, block_N), kv_group, B, threads=threads) as ( + bx, + by, + bz, + ): + T.copy( + dKV[bz, bx * block_N : (bx + 1) * block_N, by, :], + dKV_out[bz, bx * block_N : (bx + 1) * block_N, by, :], + ) + + return postprocess_kernel + + +@tilelang.jit( + out_idx=[-2], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + tilelang.PassConfigKey.TL_ENABLE_AGGRESSIVE_SHARED_MEMORY_MERGE: True, + }, +) +def bwd( + B, + S, + S_kv, + H, + D, + D_tail, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + block_size=32, + num_stages=0, + threads=128, + indices_dtype=T.int32, + dtype=T.bfloat16, + accum_dtype=T.float32, +): + assert is_causal == True, "non-casual is not supported now" + assert topk % block_size == 0, ( + "otherwise will load some index=0 thus causing wrong kv to be loaded" + ) + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + assert indices_dtype == T.int32 + + if sm_scale is None: + sm_scale = (D + D_tail) ** (-0.5) + sm_scale_mul_reciprocal_log2 = sm_scale * 1.44269504 # log2(e) + + H_kv = H // kv_group + q_shape = [B, S, H, D + D_tail] + k_shape = [B, S_kv, kv_group, D + D_tail] + o_shape = [B, S, H, D] + indices_shape = [B, S, kv_group, topk] + delta_shape = [B, S, H] + lse_shape = [B, S, H] + assert indices_dtype == T.int32 + assert dtype == T.bfloat16 + assert accum_dtype == T.float32 + + H = H_kv + padded_H = max(tilelang.math.next_power_of_2(H_kv), 16) + block_H = min(64, padded_H) + assert padded_H % block_H == 0 + NH = padded_H // block_H + BS = block_size + NS = tilelang.cdiv(topk, block_size) + + split_store = 2 + + @T.prim_func + def sparse_mla_bwd_kernel( + Q: T.Tensor(q_shape, dtype), + KV: T.Tensor(k_shape, dtype), + dO: T.Tensor(o_shape, dtype), + Indices: T.Tensor(indices_shape, indices_dtype), + Lse: T.Tensor(lse_shape, accum_dtype), + Delta: T.Tensor(delta_shape, accum_dtype), + dQ: T.Tensor(q_shape, dtype), + dKV: T.Tensor(k_shape, accum_dtype), + ): + with T.Kernel(S, B, kv_group * NH, threads=threads) as (s_i, by, bz): + Q_shared = T.alloc_shared([block_H, D], dtype) + Q_tail_shared = T.alloc_shared([block_H, D_tail], dtype) + KV_shared = T.alloc_shared([BS, D], dtype) + KV_tail_shared = T.alloc_shared([BS, D_tail], dtype) + dO_shared = T.alloc_shared([block_H, D], dtype) + mask = T.alloc_fragment([BS], "bool") + + P_shared_cast = T.alloc_shared([block_H, BS], dtype) + dP_shared_cast = T.alloc_shared([block_H, BS], dtype) + dQ_shared = T.alloc_shared([block_H, D], dtype) + dQ_tail_shared = T.alloc_shared([block_H, D_tail], dtype) + + acc_p = T.alloc_fragment([block_H, BS], accum_dtype) + acc_dp = T.alloc_fragment([block_H, BS], accum_dtype) + acc_dq = T.alloc_fragment([block_H, D], accum_dtype) + acc_dq_tail = T.alloc_fragment([block_H, D_tail], accum_dtype) + acc_dkv = T.alloc_fragment([BS, D], accum_dtype) + acc_dkv_tail = T.alloc_fragment([BS, D_tail], accum_dtype) + acc_dkv_shared = T.alloc_shared([BS // split_store, D], accum_dtype) + acc_dkv_tail_shared = T.alloc_shared( + [BS // split_store, D_tail], accum_dtype + ) + + # max_kv_i = s_i + + T.copy(Q[by, s_i, bz * block_H : (bz + 1) * block_H, :D], Q_shared) + T.copy(Q[by, s_i, bz * block_H : (bz + 1) * block_H, D:], Q_tail_shared) + T.copy(dO[by, s_i, bz * block_H : (bz + 1) * block_H, :D], dO_shared) + + T.clear(acc_dq) + T.clear(acc_dq_tail) + + # Process each block of indices + for i_i in T.Pipelined(NS, num_stages=num_stages): + # Check which indices are valid + for bi_i in T.Parallel(BS): + # Changed here for thd + mask[bi_i] = Indices[by, s_i, bz // NH, i_i * BS + bi_i] != -1 + + # Compute attention scores + for h_i, bi_i in T.Parallel(block_H, BS): + acc_p[h_i, bi_i] = T.if_then_else( + mask[bi_i], 0, -T.infinity(acc_p.dtype) + ) + + # Load KV, V for this block of indices + for bi_i, d_i in T.Parallel(BS, D): + KV_shared[bi_i, d_i] = KV[ + by, Indices[by, s_i, bz // NH, i_i * BS + bi_i], bz // NH, d_i + ] + + T.gemm( + Q_shared, + KV_shared, + acc_p, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol, + ) + + for bi_i, d_i in T.Parallel(BS, D_tail): + KV_tail_shared[bi_i, d_i] = KV[ + by, + Indices[by, s_i, bz // NH, i_i * BS + bi_i], + bz // NH, + D + d_i, + ] + T.gemm( + Q_tail_shared, + KV_tail_shared, + acc_p, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol, + ) + + for h_i, bi_i in T.Parallel(block_H, BS): + acc_p[h_i, bi_i] = T.exp2( + acc_p[h_i, bi_i] * sm_scale_mul_reciprocal_log2 + - Lse[by, s_i, bz * block_H + h_i] + ) + + T.copy(acc_p, P_shared_cast) + + T.gemm( + dO_shared, + KV_shared, + acc_dp, + transpose_B=True, + policy=T.GemmWarpPolicy.FullCol, + clear_accum=True, + ) + + for h_i, bi_i in T.Parallel(block_H, BS): + acc_dp[h_i, bi_i] = ( + acc_p[h_i, bi_i] + * (acc_dp[h_i, bi_i] - Delta[by, s_i, bz * block_H + h_i]) + * sm_scale + ) + + T.copy(acc_dp, dP_shared_cast) + T.gemm( + dP_shared_cast, KV_shared, acc_dq, policy=T.GemmWarpPolicy.FullCol + ) + T.gemm( + dP_shared_cast, + KV_tail_shared, + acc_dq_tail, + policy=T.GemmWarpPolicy.FullCol, + ) + + T.gemm( + dP_shared_cast, + Q_shared, + acc_dkv, + transpose_A=True, + policy=T.GemmWarpPolicy.FullCol, + clear_accum=True, + ) + T.gemm( + P_shared_cast, + dO_shared, + acc_dkv, + transpose_A=True, + policy=T.GemmWarpPolicy.FullCol, + ) + + T.clear(acc_dkv_tail) + T.gemm( + dP_shared_cast, + Q_tail_shared, + acc_dkv_tail, + transpose_A=True, + policy=T.GemmWarpPolicy.FullCol, + ) + + for s in range(split_store): + for bi_i, d_i in T.Parallel(BS, D): + if bi_i < BS // split_store: + acc_dkv_shared[bi_i, d_i] = acc_dkv[ + bi_i + s * (BS // split_store), d_i + ] + + for bi_i, d_i in T.Parallel(BS, D_tail): + if bi_i < BS // split_store: + acc_dkv_tail_shared[bi_i, d_i] = acc_dkv_tail[ + bi_i + s * (BS // split_store), d_i + ] + + for bi_i, d_i in T.Parallel(BS // split_store, D // 4): + T.atomic_addx4( + dKV[ + by, + Indices[ + by, + s_i, + bz // NH, + i_i * BS + bi_i + s * (BS // split_store), + ], + bz // NH, + d_i * 4, + ], + acc_dkv_shared[bi_i, d_i * 4], + ) + + # Atomically update dKV, dKV_tail tensors + for bi_i, d_i in T.Parallel(BS // split_store, D_tail // 4): + T.atomic_addx4( + dKV[ + by, + Indices[ + by, + s_i, + bz // NH, + i_i * BS + bi_i + s * (BS // split_store), + ], + bz // NH, + D + d_i * 4, + ], + acc_dkv_tail_shared[bi_i, d_i * 4], + ) + + # Store the accumulated dQ + T.copy(acc_dq, dQ_shared) + T.copy(acc_dq_tail, dQ_tail_shared) + + T.copy(dQ_shared, dQ[by, s_i, bz * block_H : (bz + 1) * block_H, :D]) + T.copy(dQ_tail_shared, dQ[by, s_i, bz * block_H : (bz + 1) * block_H, D:]) + + return sparse_mla_bwd_kernel + + +def sparse_mla_bwd( + q, + kv, + o, + do, + indices, + lse, + sm_scale=None, + is_casual=True, + return_kernel=False, + delta=None, +): + q = q.unsqueeze(0) + kv = kv.unsqueeze(0) + o = o.unsqueeze(0) + do = do.unsqueeze(0) + indices = indices.unsqueeze(0) + lse = lse.unsqueeze(0) + + assert q.is_contiguous() + assert kv.is_contiguous() + assert indices.is_contiguous() + assert lse.is_contiguous() + B, S, H, dim_plus_tail_dim = q.shape + _, S_kv, kv_group, _ = kv.shape + assert kv.shape[-1] == dim_plus_tail_dim + assert kv.shape[0] == B + # dim should be assigned + D = 512 + + D_tail = dim_plus_tail_dim - D + topk = indices.shape[-1] + assert indices.shape == (B, S, kv_group, topk) + assert lse.shape == (B, S, H) + + # Get kernels + preprocess_kernel = preprocess(B, S, H, D) + bwd_kernel = bwd(B, S, S_kv, H, D, D_tail, topk, kv_group, sm_scale, is_casual) + postprocess_kernel = postprocess(B, S_kv, D, D_tail, kv_group) + + if delta is None: + delta = preprocess_kernel(o, do) + dkv = torch.zeros_like(kv, dtype=torch.float32) + dq = bwd_kernel(q, kv, do, indices, lse, delta, dkv) + dkv = postprocess_kernel(dkv) + + dq = dq.squeeze(0) + dkv = dkv.squeeze(0) + + return dq, dkv diff --git a/areal/experimental/ops/dsa/tilelang_sparse_mla_fwd.py b/areal/experimental/ops/dsa/tilelang_sparse_mla_fwd.py new file mode 100644 index 0000000000..45134cdf0a --- /dev/null +++ b/areal/experimental/ops/dsa/tilelang_sparse_mla_fwd.py @@ -0,0 +1,231 @@ +# SPDX-License-Identifier: Apache-2.0 + +# ruff: noqa +# Adapted from https://github.com/tile-ai/tilelang/blob/e666d2d3cc483829c57618c9ebf2e4f4ada0819d/examples/deepseek_v32/sparse_mla_fwd.py +import tilelang +from tilelang import language as T + + +@tilelang.jit( + out_idx=[-2, -1], + pass_configs={ + tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True, + tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True, + }, +) +def sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group=1, + sm_scale=None, + is_causal=True, + CP0=True, + block_I=64, + num_stages=2, + threads=256, +): + assert dim == tilelang.math.next_power_of_2(dim), ( + f"haven't check padding correctness yet, dim={dim}" + ) + assert tail_dim == tilelang.math.next_power_of_2(tail_dim), ( + f"haven't check padding correctness yet, dim={tail_dim}" + ) + assert is_causal == True, "non-casual is not supported" + assert topk % block_I == 0, ( + "otherwise will load some index=0 thus causing wrong kv to be loaded" + ) + if sm_scale is None: + sm_scale = (1.0 / (dim + tail_dim)) ** 0.5 * 1.44269504 # log2(e) + else: + sm_scale = sm_scale * 1.44269504 # log2(e) + + batch = T.dynamic("batch") + seq_len = T.dynamic("seq_len") + seq_len_kv = T.dynamic("seq_len_kv") + + head_kv = heads // kv_group + q_shape = [batch, seq_len, heads, dim + tail_dim] + kv_shape = [batch, seq_len_kv, kv_group, dim + tail_dim] + o_shape = [batch, seq_len, heads, dim] + indices_shape = [batch, seq_len, kv_group, topk] + lse_shape = [batch, seq_len, heads] + indices_dtype = T.int32 + dtype = T.bfloat16 + accum_dtype = T.float32 + + G = kv_group + H = head_kv + padded_H = max(tilelang.math.next_power_of_2(head_kv), 16) + if padded_H != H: + assert kv_group == 1, ( + "here we solve the H padding automatically, other wise you should handle Q copy and Output copy with your mask (when kv_group == 1, use g_i * padded_H:(g_i+1) * padded_H would be handled automatically)" + ) + BI = block_I + NI = tilelang.cdiv(topk, block_I) + D = dim + D_tail = tail_dim + + if head_kv > 64: + assert head_kv % 64 == 0, "head_kv should be a multiple of 64" + REPLICATE_H = head_kv // 64 + else: + REPLICATE_H = 1 + + H_per_block = padded_H if REPLICATE_H == 1 else 64 + + @T.prim_func + def main( + Q: T.Tensor(q_shape, dtype), # type: ignore + KV: T.Tensor(kv_shape, dtype), # type: ignore + Indices: T.Tensor(indices_shape, indices_dtype), # type: ignore + Output: T.Tensor(o_shape, dtype), # type: ignore + Lse: T.Tensor(lse_shape, accum_dtype), # type: ignore + ): + with T.Kernel(seq_len * REPLICATE_H, batch, kv_group, threads=threads) as ( + bx, + by, + bz, + ): + Q_shared = T.alloc_shared([H_per_block, D], dtype) + Q_tail_shared = T.alloc_shared([H_per_block, D_tail], dtype) + KV_shared = T.alloc_shared([BI, D], dtype) + K_tail_shared = T.alloc_shared([BI, D_tail], dtype) + O_shared = T.alloc_shared([H_per_block, D], dtype) + Lse_shared = T.alloc_shared([H_per_block], accum_dtype) + mask = T.alloc_fragment([BI], "bool") + + acc_o = T.alloc_fragment([H_per_block, D], accum_dtype) + acc_s = T.alloc_fragment([H_per_block, BI], accum_dtype) + S_shared = T.alloc_shared([H_per_block, BI], dtype) + sumexp = T.alloc_fragment([H_per_block], accum_dtype) + sumexp_i = T.alloc_fragment([H_per_block], accum_dtype) + alpha = T.alloc_fragment([H_per_block], accum_dtype) + m_i = T.alloc_fragment([H_per_block], accum_dtype) + m_i_prev = T.alloc_fragment([H_per_block], accum_dtype) + + T.fill(acc_o, 0) + T.fill(sumexp, 0) + T.fill(m_i, -(2**30)) # avoid -inf - inf to cause nan + + b_i, g_i = by, bz + s_i = bx if REPLICATE_H == 1 else (bx // REPLICATE_H) + q_i = s_i + max_kv_i = q_i + + H0 = g_i * padded_H + (0 if REPLICATE_H == 1 else (bx % REPLICATE_H) * 64) + H1 = H0 + H_per_block + + T.copy(Q[b_i, s_i, H0:H1, :D], Q_shared) + T.copy(Q[b_i, s_i, H0:H1, D:], Q_tail_shared) + + for i_i in T.Pipelined(NI, num_stages=num_stages): + for bi_i in T.Parallel(BI): + # Changed here for thd + mask[bi_i] = Indices[b_i, s_i, g_i, i_i * BI + bi_i] != -1 + + for bi_i, d_i in T.Parallel(BI, D): + KV_shared[bi_i, d_i] = KV[ + b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, d_i + ] + for bi_i, d_i in T.Parallel(BI, D_tail): + K_tail_shared[bi_i, d_i] = KV[ + b_i, Indices[b_i, s_i, g_i, i_i * BI + bi_i], g_i, D + d_i + ] + + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.if_then_else( + mask[bi_i], 0, -T.infinity(acc_s.dtype) + ) + T.gemm( + Q_shared, + KV_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.gemm( + Q_tail_shared, + K_tail_shared, + acc_s, + transpose_B=True, + policy=T.GemmWarpPolicy.FullRow, + ) + T.copy(m_i, m_i_prev) + T.reduce_max(acc_s, m_i, dim=1, clear=False) + for h_i in T.Parallel(H_per_block): + m_i[h_i] = T.max(m_i[h_i], m_i_prev[h_i]) + for h_i in T.Parallel(H_per_block): + alpha[h_i] = T.exp2((m_i_prev[h_i] - m_i[h_i]) * sm_scale) + for h_i, bi_i in T.Parallel(H_per_block, BI): + acc_s[h_i, bi_i] = T.exp2( + acc_s[h_i, bi_i] * sm_scale - m_i[h_i] * sm_scale + ) + T.reduce_sum(acc_s, sumexp_i, dim=1) # is this a accumulate operator? + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = sumexp[h_i] * alpha[h_i] + sumexp_i[h_i] + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] = acc_o[h_i, d_i] * alpha[h_i] + + T.copy(acc_s, S_shared) + T.gemm(S_shared, KV_shared, acc_o, policy=T.GemmWarpPolicy.FullRow) + + # Rescale + for h_i, d_i in T.Parallel(H_per_block, D): + acc_o[h_i, d_i] /= sumexp[h_i] + for h_i in T.Parallel(H_per_block): + sumexp[h_i] = T.log2(sumexp[h_i]) + m_i[h_i] * sm_scale + + T.copy(acc_o, Output[b_i, s_i, H0:H1, :]) + T.copy(sumexp, Lse[b_i, s_i, H0:H1]) + + return main + + +def sparse_mla_fwd_interface( + q, + kv, + indices, + sm_scale=None, + return_p_sum: bool = False, + d_v=512, + block_I=64, + num_stages=2, + threads=256, +): + q = q.unsqueeze(0) + kv = kv.unsqueeze(0) + indices = indices.unsqueeze(0) + + is_casual = True + assert return_p_sum == False, "This kernel file is for fwd only" + assert q.is_contiguous() and kv.is_contiguous() and indices.is_contiguous() + batch, seq_len, heads, dim_plus_tail_dim = q.shape + _, seq_len_kv, kv_group, _ = kv.shape + + assert dim_plus_tail_dim == 576, "you should assign dim otherwise" + dim = d_v + + assert kv.shape[-1] == dim_plus_tail_dim + tail_dim = dim_plus_tail_dim - dim + assert kv.shape[0] == batch + _, _, _, topk = indices.shape + assert indices.shape == (batch, seq_len, kv_group, topk) + + kernel = sparse_mla_fwd( + heads, + dim, + tail_dim, + topk, + kv_group, + sm_scale, + is_casual, + block_I=block_I, + num_stages=num_stages, + threads=threads, + ) + out, lse = kernel(q, kv, indices) + out = out.squeeze(0) + lse = lse.squeeze(0) + return out, lse diff --git a/areal/models/mcore/deepseek_v3.py b/areal/models/mcore/deepseek_v3.py new file mode 100644 index 0000000000..457260ff82 --- /dev/null +++ b/areal/models/mcore/deepseek_v3.py @@ -0,0 +1,354 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""DeepseekV3ForCausalLM / GLM-5.1 / GLM-4.7-Flash support for megatron-core. + +This module provides: +1. HF config -> MLATransformerConfig conversion +2. Homogeneous MLA layer spec construction +3. DSA (Dynamic Sparse Attention) support for GLM-5.1 + +DeepSeek V3 / GLM-5.1 / GLM-4.7-Flash uses: +- MLA (Multi-head Latent Attention) for all layers +- MoE: sigmoid routing, grouped TopK, shared experts +- Dense layers for first `first_k_dense_replace` layers, MoE for the rest +- YaRN RoPE scaling (optional, GLM-4.7-Flash uses plain RoPE) +- DSA indexer (GLM-5.1 only): per-layer sparse attention token selector + +Note: The MLA RoPE patch for CP>1 is applied in bailing_moe.py at module level +and automatically benefits all MLA models including DeepSeek V3. +""" + +import os + +import torch +from megatron.core.models.gpt.gpt_layer_specs import ( + get_gpt_layer_with_transformer_engine_spec, +) +from megatron.core.transformer.enums import LayerType +from megatron.core.transformer.multi_latent_attention import MLATransformerConfig +from megatron.core.transformer.transformer_block import ( + TransformerBlockSubmodules, + get_num_layers_to_build, +) +from megatron.core.transformer.transformer_layer import get_transformer_layer_offset +from transformers import PretrainedConfig + +from areal.models.mcore.common import check_and_construct_configs, hf_to_mcore_base_args +from areal.utils import logging + +logger = logging.getLogger("DeepSeekV3") + + +def _has_dsa(hf_config: PretrainedConfig) -> bool: + # DSA enabled when the HF config exposes indexer topk + n_heads. Uses a + # slime-style native DSA MLA module (see dsa_mla_attention.py) that + # inherits Attention directly instead of going through mcore's DSAttention + # container, so packed THD inputs work without modification. + return ( + getattr(hf_config, "index_topk", None) is not None + and getattr(hf_config, "index_n_heads", None) is not None + ) + + +def hf_to_mcore_config_deepseek_v3( + hf_config: PretrainedConfig, + dtype: torch.dtype, +) -> MLATransformerConfig: + """Convert DeepSeek V3 / GLM-5.1 HuggingFace config to MLATransformerConfig. + + DeepSeek V3 architecture uses MLA for all layers (no Lightning Attention), + which makes it simpler than BailingMoeV2_5. + + Args: + hf_config: HuggingFace PretrainedConfig for DeepseekV3ForCausalLM + dtype: Data type for the model parameters + + Returns: + MLATransformerConfig with MLA + MoE parameters + """ + # MTP layers are not used during RL training (only for SGLang EAGLE-style + # inference). Setting AREAL_DISABLE_MTP=1 zeroes out + # num_nextn_predict_layers so no MTP module is built, preventing rare bwd + # NaN paths through the MTP block. + if os.environ.get("AREAL_DISABLE_MTP", "0") == "1": + if getattr(hf_config, "num_nextn_predict_layers", 0): + logger.warning( + f"AREAL_DISABLE_MTP=1: overriding " + f"hf_config.num_nextn_predict_layers " + f"from {hf_config.num_nextn_predict_layers} to 0" + ) + hf_config.num_nextn_predict_layers = 0 + + # Build moe_layer_freq: 0 for dense, 1 for MoE + num_layers = hf_config.num_hidden_layers + first_k_dense_replace = getattr(hf_config, "first_k_dense_replace", 3) + moe_layer_freq = [0 if i < first_k_dense_replace else 1 for i in range(num_layers)] + + # Shared expert intermediate size + n_shared_experts = getattr(hf_config, "n_shared_experts", 0) + moe_intermediate_size = getattr( + hf_config, "moe_intermediate_size", hf_config.intermediate_size + ) + shared_expert_intermediate_size = ( + n_shared_experts * moe_intermediate_size if n_shared_experts > 0 else None + ) + + # Get base args common to all models + base_args = hf_to_mcore_base_args( + hf_config=hf_config, + dtype=dtype, + use_cpu_initialization=False, + add_bias_linear=False, + add_qkv_bias=False, + qk_layernorm=True, + ) + + # MLA-specific parameters + # + # DeepSeek V3 uses YaRN RoPE scaling. The rotary_scaling_factor, mscale, + # and mscale_all_dim must be set correctly from the HF config's rope_scaling. + # + # DeepSeek V3 HF config rope_scaling example: + # {"type": "yarn", "factor": 4.0, "mscale": 0.707, "mscale_all_dim": 0.707, ...} + rope_scaling = getattr(hf_config, "rope_scaling", None) or {} + rotary_scaling_factor = rope_scaling.get("factor", 1.0) + + # rope_theta: top-level field or inside rope_parameters (GLM-5.1) + rope_theta = getattr(hf_config, "rope_theta", None) + if rope_theta is None: + rope_params = getattr(hf_config, "rope_parameters", None) or {} + rope_theta = rope_params.get("rope_theta", 10000.0) + + mla_args = { + "multi_latent_attention": True, + "q_lora_rank": getattr(hf_config, "q_lora_rank", None), + "kv_lora_rank": getattr(hf_config, "kv_lora_rank", 512), + "qk_head_dim": getattr(hf_config, "qk_nope_head_dim", 128), + "qk_pos_emb_head_dim": getattr(hf_config, "qk_rope_head_dim", 64), + "v_head_dim": getattr(hf_config, "v_head_dim", 128), + # RoPE + "rope_type": "rope", + "rotary_base": rope_theta, + "rotary_percent": getattr(hf_config, "partial_rotary_factor", 1.0), + "rotary_scaling_factor": rotary_scaling_factor, + "apply_rope_fusion": False, + } + if rope_scaling.get("type") == "yarn" or rope_scaling.get("rope_type") == "yarn": + mla_args["mscale"] = rope_scaling.get("mscale", 0.707) + mla_args["mscale_all_dim"] = rope_scaling.get("mscale_all_dim", 0.707) + + # MoE-specific parameters + n_routed_experts = getattr(hf_config, "n_routed_experts", None) + if n_routed_experts is None: + n_routed_experts = getattr(hf_config, "num_local_experts", None) + + moe_args = { + "num_moe_experts": n_routed_experts, + "moe_router_topk": getattr(hf_config, "num_experts_per_tok", 8), + "moe_router_score_function": getattr(hf_config, "scoring_func", "sigmoid"), + "moe_router_num_groups": getattr(hf_config, "n_group", 8), + "moe_router_group_topk": getattr(hf_config, "topk_group", 4), + "moe_router_topk_scaling_factor": getattr( + hf_config, "routed_scaling_factor", None + ), + "moe_ffn_hidden_size": moe_intermediate_size, + "moe_shared_expert_intermediate_size": shared_expert_intermediate_size, + "moe_layer_freq": moe_layer_freq, + "moe_router_enable_expert_bias": True, + "moe_router_load_balancing_type": "none", + "moe_grouped_gemm": True, + "moe_router_dtype": "fp32", + "moe_router_bias_update_rate": 0.0, + "moe_z_loss_coeff": 3.5e-6, + "moe_enable_routing_replay": bool(os.environ.get("AREAL_DUMP_ROUTING", "")), + } + if moe_args["moe_enable_routing_replay"]: + logger.info("AREAL_DUMP_ROUTING is set; moe_enable_routing_replay=True") + + # Numerical stability flags: + # bf16 attention softmax + MoE forward output can amplify numeric range, + # breaking bf16 linear/RMSNorm backward on long-context training. + # attention_softmax_in_fp32 is the most impactful stability knob. + # check_and_construct_configs (common.py) silently drops keys not on + # MLATransformerConfig, so older mcore versions still work. + stability_args = { + "attention_softmax_in_fp32": True, + "cross_entropy_loss_fusion": False, # use fp32 unfused cross-entropy + "disable_bf16_reduced_precision_matmul": True, + } + + # Merge all args + all_args = {**base_args, **mla_args, **moe_args, **stability_args} + + # DSA (Dynamic Sparse Attention) parameters for GLM-5.1 + if _has_dsa(hf_config): + dsa_indexer_loss_coeff = getattr(hf_config, "dsa_indexer_loss_coeff", 0.0) + # NOTE: do NOT set experimental_attention_variant="dsa" — that triggers + # mcore's own DSA code paths inside multi_latent_attention.py, which we + # bypass by providing a slime-style custom self_attention module spec. + dsa_args = { + "dsa_indexer_n_heads": hf_config.index_n_heads, + "dsa_indexer_head_dim": hf_config.index_head_dim, + "dsa_indexer_topk": hf_config.index_topk, + "dsa_indexer_loss_coeff": dsa_indexer_loss_coeff, + "dsa_indexer_use_sparse_loss": getattr( + hf_config, "dsa_indexer_use_sparse_loss", False + ), + } + all_args.update(dsa_args) + logger.info( + f"DSA enabled: index_n_heads={hf_config.index_n_heads}, " + f"index_head_dim={hf_config.index_head_dim}, " + f"index_topk={hf_config.index_topk}, " + f"indexer_loss_coeff={dsa_indexer_loss_coeff}" + ) + + return check_and_construct_configs(all_args, MLATransformerConfig) + + +def make_mcore_layer_specs_deepseek_v3( + tf_config: MLATransformerConfig, + hf_config: PretrainedConfig, + use_te: bool = True, + vp_stage: int | None = None, +) -> TransformerBlockSubmodules: + """Build homogeneous MLA layer specs for DeepSeek V3 / GLM-5.1. + + All layers use MLA attention. The only variation is Dense MLP vs MoE MLP, + determined by `first_k_dense_replace`. + + Args: + tf_config: MLATransformerConfig with all model parameters + hf_config: HF config for first_k_dense_replace + use_te: Whether to use Transformer Engine modules + vp_stage: Virtual pipeline stage (for VPP support) + + Returns: + TransformerBlockSubmodules with MLA layer specs (PP-sliced if PP>1) + """ + assert tf_config.normalization == "RMSNorm", "only RMSNorm is supported" + + num_layers = tf_config.num_layers + first_k_dense_replace = getattr(hf_config, "first_k_dense_replace", 3) + use_dsa = _has_dsa(hf_config) + + if use_dsa: + # Build a slime-style DSA self-attention spec that wraps our custom + # DSAMLASelfAttention (inherits Attention directly, not DSAttention). + from megatron.core.extensions.transformer_engine import ( + TEDotProductAttention, + TELayerNormColumnParallelLinear, + TELinear, + TENorm, + TERowParallelLinear, + ) + from megatron.core.transformer.enums import AttnMaskType + from megatron.core.transformer.identity_op import IdentityOp + from megatron.core.transformer.spec_utils import ModuleSpec + + from areal.models.mcore.dsa_mla_attention import ( + DSAMLASelfAttention, + DSASelfAttentionSubmodules, + ) + + dsa_attention_spec = ModuleSpec( + module=DSAMLASelfAttention, + params={"attn_mask_type": AttnMaskType.causal}, + submodules=DSASelfAttentionSubmodules( + linear_q_down_proj=TELinear, + linear_q_up_proj=TELayerNormColumnParallelLinear, + linear_kv_down_proj=TELinear, + linear_kv_up_proj=TELayerNormColumnParallelLinear, + linear_v_up_proj=IdentityOp, + core_attention=TEDotProductAttention, + linear_proj=TERowParallelLinear, + q_layernorm=IdentityOp, + kv_layernorm=IdentityOp, + wq_b=TELinear, + wk=TELinear, + k_norm=TENorm, + weights_proj=TELinear, + ), + ) + + # Build MLA layer specs (all layers use MLA, optionally with DSA) + def _make_layer_spec(num_experts, moe_grouped_gemm): + base_spec = get_gpt_layer_with_transformer_engine_spec( + num_experts=num_experts, + moe_grouped_gemm=moe_grouped_gemm, + qk_layernorm=tf_config.qk_layernorm, + multi_latent_attention=True, + ) + if use_dsa: + base_spec.submodules.self_attention = dsa_attention_spec + return base_spec + + mla_dense_spec = _make_layer_spec(num_experts=None, moe_grouped_gemm=False) + mla_moe_spec = _make_layer_spec( + num_experts=tf_config.num_moe_experts, + moe_grouped_gemm=tf_config.moe_grouped_gemm, + ) + + # Build per-layer specs + layer_specs = [] + for layer_idx in range(num_layers): + is_moe = layer_idx >= first_k_dense_replace + spec = mla_moe_spec if is_moe else mla_dense_spec + layer_specs.append(spec) + + n_moe = sum(1 for i in range(num_layers) if i >= first_k_dense_replace) + n_dense = num_layers - n_moe + attn_type = "MLA+DSA" if use_dsa else "MLA" + logger.info( + f"Built DeepSeek V3 layer specs: {num_layers} layers (all {attn_type}), " + f"first_k_dense={first_k_dense_replace}, " + f"num_experts={tf_config.num_moe_experts}" + ) + logger.info(f"Layer composition: {n_dense} Dense + {n_moe} MoE") + + # PP slicing: when PP>1, only include layers for the current pipeline stage. + num_layers_to_build = get_num_layers_to_build(tf_config, vp_stage=vp_stage) + + if tf_config.pipeline_model_parallel_layout is not None: + local_layer_specs = [ + layer_specs[layer_id] + for layer_id in tf_config.pipeline_model_parallel_layout.get_layer_id_list( + layer_type=LayerType.decoder, vp_stage=vp_stage + ) + ] + elif num_layers_to_build < num_layers: + offset = get_transformer_layer_offset(tf_config, vp_stage=vp_stage) + local_layer_specs = layer_specs[offset : offset + num_layers_to_build] + else: + local_layer_specs = layer_specs + + if len(local_layer_specs) != num_layers: + logger.info( + f"PP slicing: building {len(local_layer_specs)}/{num_layers} layers " + f"for this pipeline stage" + ) + + # Get layer norm implementation + if use_te: + try: + from megatron.core.extensions.transformer_engine import TENorm + + layer_norm_impl = TENorm + except ImportError: + from megatron.core.transformer.torch_norm import WrappedTorchNorm + + layer_norm_impl = WrappedTorchNorm + else: + try: + from megatron.core.fusions.fused_layer_norm import FusedLayerNorm + + layer_norm_impl = FusedLayerNorm + except ImportError: + from megatron.core.transformer.torch_norm import WrappedTorchNorm + + layer_norm_impl = WrappedTorchNorm + + return TransformerBlockSubmodules( + layer_specs=local_layer_specs, + layer_norm=layer_norm_impl, + ) diff --git a/areal/models/mcore/deepseek_v3_bridge.py b/areal/models/mcore/deepseek_v3_bridge.py new file mode 100644 index 0000000000..faea102a9e --- /dev/null +++ b/areal/models/mcore/deepseek_v3_bridge.py @@ -0,0 +1,378 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""mbridge Bridge for DeepSeek V3 / GLM-5.1 / GLM-4.7-Flash. + +Registers with mbridge so that MegatronEngine.initialize() can use AutoBridge +to load and manage DeepSeek V3 / GLM-5.1 / GLM-4.7-Flash models with +homogeneous MLA attention and MoE layers. + +Key differences from BailingMoeBridge: +- All layers use MLA (no Lightning Attention heterogeneity) +- q_lora_rank is always non-None (Q uses low-rank decomposition) +- No fused QKV weight conversion needed +- HF uses 'self_attn' prefix (not 'attention') and 'o_proj' (not 'dense') +- HF embedding key is 'model.embed_tokens.weight' (not 'model.word_embeddings.weight') +- MoE expert count field is 'n_routed_experts' (not 'num_experts') +- scoring_func defaults to sigmoid +- GLM-5.1 adds DSA indexer weights (wq_b, wk, k_norm, weights_proj) + +Note: GLM-4.7-Flash has num_nextn_predict_layers=1 with extra weights at +layer index num_hidden_layers. These weights are automatically ignored since +megatron-core only builds num_hidden_layers transformer layers. +""" + +import os + +import torch +from mbridge.core import LLMBridge, register_model +from megatron.core.transformer import MLATransformerConfig +from megatron.core.transformer.enums import AttnBackend + +from areal.models.mcore.deepseek_v3 import make_mcore_layer_specs_deepseek_v3 +from areal.utils import logging + +logger = logging.getLogger("DeepSeekV3Bridge") + +# MLA Q-LoRA mapping (mcore suffix -> HF name templates) +# DeepSeek V3 / GLM-5.1 always uses Q-LoRA (q_lora_rank=1536) +_MLA_Q_LORA_MAPPING = { + "self_attention.linear_q_down_proj.weight": [ + "model.layers.{layer_number}.self_attn.q_a_proj.weight" + ], + "self_attention.linear_q_up_proj.layer_norm_weight": [ + "model.layers.{layer_number}.self_attn.q_a_layernorm.weight" + ], + "self_attention.linear_q_up_proj.weight": [ + "model.layers.{layer_number}.self_attn.q_b_proj.weight" + ], +} + +# MLA KV compression + output projection mapping +_MLA_COMMON_MAPPING = { + "input_layernorm.weight": ["model.layers.{layer_number}.input_layernorm.weight"], + "self_attention.linear_kv_down_proj.weight": [ + "model.layers.{layer_number}.self_attn.kv_a_proj_with_mqa.weight" + ], + "self_attention.linear_kv_up_proj.layer_norm_weight": [ + "model.layers.{layer_number}.self_attn.kv_a_layernorm.weight" + ], + "self_attention.linear_kv_up_proj.weight": [ + "model.layers.{layer_number}.self_attn.kv_b_proj.weight" + ], + "self_attention.linear_proj.weight": [ + "model.layers.{layer_number}.self_attn.o_proj.weight" + ], +} + +# Combined MLA attention mapping (always Q-LoRA for DeepSeek V3) +_MLA_ATTENTION_MAPPING = {**_MLA_COMMON_MAPPING, **_MLA_Q_LORA_MAPPING} + +# DSA indexer weight mapping (GLM-5.1 only) +# slime-style DSAMLASelfAttention attaches indexer submodules directly on the +# self-attention module (NOT inside a core_attention.indexer container), using +# bare names wq_b / wk / k_norm / weights_proj. LayerNorm k_norm has weight+bias. +_DSA_INDEXER_MAPPING = { + "self_attention.wq_b.weight": [ + "model.layers.{layer_number}.self_attn.indexer.wq_b.weight" + ], + "self_attention.wk.weight": [ + "model.layers.{layer_number}.self_attn.indexer.wk.weight" + ], + "self_attention.k_norm.weight": [ + "model.layers.{layer_number}.self_attn.indexer.k_norm.weight" + ], + "self_attention.k_norm.bias": [ + "model.layers.{layer_number}.self_attn.indexer.k_norm.bias" + ], + "self_attention.weights_proj.weight": [ + "model.layers.{layer_number}.self_attn.indexer.weights_proj.weight" + ], +} + + +@register_model("deepseek_v3") +@register_model("glm_moe_dsa") +@register_model("glm4_moe_lite") +class DeepSeekV3Bridge(LLMBridge): + """Bridge for DeepSeek V3 / GLM-5.1 with homogeneous MLA + MoE.""" + + TransformerConfigClass = MLATransformerConfig + + @property + def _has_dsa_indexer(self) -> bool: + # slime-style native DSA module: wq_b / wk / k_norm / weights_proj are + # attached directly on DSAMLASelfAttention (see dsa_mla_attention.py), + # NOT inside a DSAttention container. Enabled when HF config exposes + # index_topk + index_n_heads. + return ( + getattr(self.hf_config, "index_topk", None) is not None + and getattr(self.hf_config, "index_n_heads", None) is not None + ) + + _DIRECT_MAPPING = { + "embedding.word_embeddings.weight": "model.embed_tokens.weight", + "decoder.final_layernorm.weight": "model.norm.weight", + "output_layer.weight": "lm_head.weight", + } + + _MLP_MAPPING = { + # Dense MLP (layers < first_k_dense_replace) + "mlp.linear_fc1.layer_norm_weight": [ + "model.layers.{layer_number}.post_attention_layernorm.weight" + ], + "mlp.linear_fc2.weight": ["model.layers.{layer_number}.mlp.down_proj.weight"], + "mlp.linear_fc1.weight": [ + "model.layers.{layer_number}.mlp.gate_proj.weight", + "model.layers.{layer_number}.mlp.up_proj.weight", + ], + # MoE shared experts + "mlp.shared_experts.linear_fc2.weight": [ + "model.layers.{layer_number}.mlp.shared_experts.down_proj.weight" + ], + "mlp.shared_experts.linear_fc1.weight": [ + "model.layers.{layer_number}.mlp.shared_experts.gate_proj.weight", + "model.layers.{layer_number}.mlp.shared_experts.up_proj.weight", + ], + # MoE pre-MLP layernorm + "pre_mlp_layernorm.weight": [ + "model.layers.{layer_number}.post_attention_layernorm.weight" + ], + # MoE router + "mlp.router.weight": ["model.layers.{layer_number}.mlp.gate.weight"], + "mlp.router.expert_bias": [ + "model.layers.{layer_number}.mlp.gate.e_score_correction_bias" + ], + # MoE experts + "mlp.experts.linear_fc1.weight": [ + "model.layers.{layer_number}.mlp.experts.{expert_id}.gate_proj.weight", + "model.layers.{layer_number}.mlp.experts.{expert_id}.up_proj.weight", + ], + "mlp.experts.linear_fc2.weight": [ + "model.layers.{layer_number}.mlp.experts.{expert_id}.down_proj.weight" + ], + } + + def _build_config(self): + hf_config = self.hf_config + + # Build moe_layer_freq: 0 for dense, 1 for MoE + num_layers = hf_config.num_hidden_layers + first_k_dense_replace = getattr(hf_config, "first_k_dense_replace", 3) + moe_layer_freq = [ + 0 if i < first_k_dense_replace else 1 for i in range(num_layers) + ] + + # Shared expert intermediate size + n_shared_experts = getattr(hf_config, "n_shared_experts", 0) + moe_intermediate_size = getattr( + hf_config, "moe_intermediate_size", hf_config.intermediate_size + ) + shared_expert_intermediate_size = ( + n_shared_experts * moe_intermediate_size if n_shared_experts > 0 else None + ) + + # Number of routed experts + n_routed_experts = getattr(hf_config, "n_routed_experts", None) + if n_routed_experts is None: + n_routed_experts = getattr(hf_config, "num_local_experts", None) + + # YaRN RoPE scaling parameters + # G20 (2026-05-08): only pass mscale/mscale_all_dim when rope_scaling + # actually requests YaRN (rope_type=='yarn'). GLM-5.1 has + # rope_scaling={'rope_theta':..., 'rope_type':'default'} which is a + # truthy dict but NOT YaRN — previously we set mscale=0.707 default, + # which 0.5x'd softmax_scale and caused 7x SFT loss vs sglang. + rope_scaling_dict = getattr(hf_config, "rope_scaling", None) or {} + rotary_scaling_factor = rope_scaling_dict.get("factor", 1.0) + + # rope_theta: top-level field or inside rope_parameters (GLM-5.1) + rope_theta = getattr(hf_config, "rope_theta", None) + if rope_theta is None: + rope_params = getattr(hf_config, "rope_parameters", None) or {} + rope_theta = rope_params.get("rope_theta", 10000.0) + + mscale_kwargs = {} + if ( + rope_scaling_dict.get("rope_type") == "yarn" + or rope_scaling_dict.get("type") == "yarn" + ): + mscale_kwargs["mscale"] = rope_scaling_dict.get("mscale", 0.707) + mscale_kwargs["mscale_all_dim"] = rope_scaling_dict.get( + "mscale_all_dim", 0.707 + ) + + return self._build_base_config( + attention_backend=AttnBackend.fused, + layernorm_epsilon=hf_config.rms_norm_eps, + ffn_hidden_size=hf_config.intermediate_size, + qk_layernorm=True, + # MLA parameters + multi_latent_attention=True, + q_lora_rank=getattr(hf_config, "q_lora_rank", None), + kv_lora_rank=getattr(hf_config, "kv_lora_rank", 512), + qk_head_dim=getattr(hf_config, "qk_nope_head_dim", 128), + qk_pos_emb_head_dim=getattr(hf_config, "qk_rope_head_dim", 64), + v_head_dim=getattr(hf_config, "v_head_dim", 128), + rotary_base=rope_theta, + rope_type="rope", + rotary_percent=getattr(hf_config, "partial_rotary_factor", 1.0), + rotary_scaling_factor=rotary_scaling_factor, + apply_rope_fusion=False, + **mscale_kwargs, + # MoE parameters + moe_ffn_hidden_size=moe_intermediate_size, + moe_token_dispatcher_type="alltoall", + moe_router_enable_expert_bias=True, + moe_router_topk=getattr(hf_config, "num_experts_per_tok", 8), + num_moe_experts=n_routed_experts, + moe_shared_expert_intermediate_size=shared_expert_intermediate_size, + moe_router_score_function="sigmoid", + moe_router_num_groups=getattr(hf_config, "n_group", 8), + moe_router_group_topk=getattr(hf_config, "topk_group", 4), + moe_router_topk_scaling_factor=getattr( + hf_config, "routed_scaling_factor", None + ), + moe_router_load_balancing_type="none", + moe_grouped_gemm=True, + moe_layer_freq=moe_layer_freq, + moe_router_dtype="fp32", + moe_router_bias_update_rate=0.0, + moe_z_loss_coeff=3.5e-6, + moe_enable_routing_replay=bool(os.environ.get("AREAL_DUMP_ROUTING", "")), + # Other + persist_layer_norm=True, + bias_activation_fusion=True, + bias_dropout_fusion=True, + # DSA parameters (GLM-5.1). + # NOTE: do NOT set experimental_attention_variant="dsa" — we use a + # slime-style custom self_attention spec (DSAMLASelfAttention) and + # bypass mcore's own DSA paths. + **( + { + "dsa_indexer_n_heads": hf_config.index_n_heads, + "dsa_indexer_head_dim": hf_config.index_head_dim, + "dsa_indexer_topk": hf_config.index_topk, + "dsa_indexer_loss_coeff": getattr( + hf_config, "dsa_indexer_loss_coeff", 0.0 + ), + "dsa_indexer_use_sparse_loss": getattr( + hf_config, "dsa_indexer_use_sparse_loss", False + ), + } + if self._has_dsa_indexer + else {} + ), + ) + + def _get_gptmodel_args(self) -> dict: + rope_theta = getattr(self.hf_config, "rope_theta", None) + if rope_theta is None: + rope_params = getattr(self.hf_config, "rope_parameters", None) or {} + rope_theta = rope_params.get("rope_theta", 10000.0) + return dict( + vocab_size=self.hf_config.vocab_size, + max_sequence_length=self.hf_config.max_position_embeddings, + position_embedding_type="rope", + rotary_base=rope_theta, + ) + + def _get_transformer_layer_spec(self, vp_stage: int | None = None): + """Return homogeneous MLA layer specs (all layers use MLA). + + PP slicing is handled inside make_mcore_layer_specs_deepseek_v3. + """ + assert self.config.normalization == "RMSNorm" + self.has_vp_stage = vp_stage is not None + return make_mcore_layer_specs_deepseek_v3( + self.config, self.hf_config, use_te=True, vp_stage=vp_stage + ) + + def _weight_name_mapping_mcore_to_hf(self, mcore_weights_name: str) -> list[str]: + assert "_extra_state" not in mcore_weights_name + + if mcore_weights_name in self._DIRECT_MAPPING: + return [self._DIRECT_MAPPING[mcore_weights_name]] + + if ( + "self_attention" in mcore_weights_name + or "input_layernorm.weight" in mcore_weights_name + ): + return self._weight_name_mapping_attention(mcore_weights_name) + elif "mlp" in mcore_weights_name or "pre_mlp_layernorm" in mcore_weights_name: + return self._weight_name_mapping_mlp(mcore_weights_name) + else: + raise NotImplementedError( + f"Unsupported parameter name: {mcore_weights_name}" + ) + + def _weight_merge_across_tp( + self, + mcore_weights_name: str, + tp_shards: list[torch.Tensor], + param: torch.Tensor, + ) -> torch.Tensor: + """Handle MLA and DSA duplicated weights. + + linear_q_down_proj and linear_kv_down_proj use parallel_mode='duplicated' + in megatron-core MLA — they are replicated (not sharded) across TP ranks. + DSA indexer weights (wq_b, wk, k_norm, weights_proj) live directly under + self_attention in our slime-style DSAMLASelfAttention and are also + duplicated via parallel_mode='duplicated'. + All shards are identical, so just return the first one. + """ + if ( + "linear_q_down_proj." in mcore_weights_name + or "linear_kv_down_proj." in mcore_weights_name + or "self_attention.wq_b." in mcore_weights_name + or "self_attention.wk." in mcore_weights_name + or "self_attention.k_norm." in mcore_weights_name + or "self_attention.weights_proj." in mcore_weights_name + ): + return tp_shards[0].clone() + return super()._weight_merge_across_tp(mcore_weights_name, tp_shards, param) + + def _weight_name_mapping_attention(self, name: str) -> list[str]: + """Map MLA attention weights. All layers use MLA (no heterogeneous dispatch). + + For GLM-5.1, also handles DSA indexer weights. + """ + layer_number_str = name.split(".")[2] + + # Check DSA indexer mappings first + mapping = _MLA_ATTENTION_MAPPING + if self._has_dsa_indexer: + mapping = {**mapping, **_DSA_INDEXER_MAPPING} + + convert_names = [] + for keyword, mapping_names in mapping.items(): + if keyword in name: + convert_names.extend( + [x.format(layer_number=layer_number_str) for x in mapping_names] + ) + break + + if not convert_names: + raise NotImplementedError(f"Unsupported attention parameter: {name}") + return convert_names + + def _weight_name_mapping_mlp(self, name: str) -> list[str]: + layer_number = name.split(".")[2] + convert_names = [] + for keyword, mapping_names in self._MLP_MAPPING.items(): + if keyword in name: + if "{expert_id}" in mapping_names[0]: + expert_id = name.split("weight")[-1] + convert_names.extend( + [ + x.format(layer_number=layer_number, expert_id=expert_id) + for x in mapping_names + ] + ) + else: + convert_names.extend( + [x.format(layer_number=layer_number) for x in mapping_names] + ) + break + if not convert_names: + raise NotImplementedError(f"Unsupported MLP parameter: {name}") + return convert_names diff --git a/areal/models/mcore/dsa_mla_attention.py b/areal/models/mcore/dsa_mla_attention.py new file mode 100644 index 0000000000..d07a741cc3 --- /dev/null +++ b/areal/models/mcore/dsa_mla_attention.py @@ -0,0 +1,699 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""DSA (Deep Sparse Attention) MLA self-attention for GLM-5.1. + +Ported from slime_plugins/models/glm5/glm5.py (L33-L604) with these adaptations +for AReaL: + +* Drop modelopt-based `Linear` path — only TE/standard linears supported here. +* Drop `backward_dw` / `set_for_recompute_input_layernorm` — AReaL does not + split weight gradient updates across micro-batch boundaries. +* Use TE's `fused_apply_rotary_pos_emb_thd` (slime notes precision is slightly + worse than apex's, but apex.transformer is unavailable in our container). +* `index_topk` falls back to 2048 if `config.dsa_indexer_topk` is absent. + +This module inherits `megatron.core.transformer.attention.Attention` directly +(NOT `mcore` DSAttention container), so packed THD inputs flow naturally +without the `assert packed_seq_params is None` that mcore's DSAttention raises. +""" + +import math +from dataclasses import dataclass +from typing import NoReturn + +import torch +from megatron.core import parallel_state +from megatron.core.extensions.transformer_engine import ( + TEColumnParallelLinear, + TELinear, + fused_apply_rotary_pos_emb_thd, +) +from megatron.core.models.common.embeddings import ( + RotaryEmbedding, + YarnRotaryEmbedding, + _yarn_get_mscale, +) +from megatron.core.tensor_parallel.layers import ColumnParallelLinear +from megatron.core.tensor_parallel.mappings import ( + gather_from_sequence_parallel_region, +) +from megatron.core.transformer.attention import Attention +from megatron.core.transformer.enums import AttnMaskType +from megatron.core.transformer.moe.moe_utils import ( + RouterGatingLinearFunction as WeightLinearFunction, +) +from megatron.core.transformer.spec_utils import ModuleSpec, build_module +from megatron.core.transformer.transformer_config import MLATransformerConfig + +from areal.experimental.ops.dsa.indexer import ( + generate_varlen_mask_params, + lighting_indexer, +) +from areal.experimental.ops.dsa.sparse_mla import SparseMLA + + +@dataclass +class DSASelfAttentionSubmodules: + """Submodules for the DSA MLA self-attention layer.""" + + linear_q_down_proj: ModuleSpec | type = None + linear_q_up_proj: ModuleSpec | type = None + linear_kv_down_proj: ModuleSpec | type = None + linear_kv_up_proj: ModuleSpec | type = None + linear_v_up_proj: ModuleSpec | type = None + core_attention: ModuleSpec | type = None + linear_proj: ModuleSpec | type = None + q_layernorm: ModuleSpec | type = None + kv_layernorm: ModuleSpec | type = None + # added for indexer + wq_b: ModuleSpec | type = None + wk: ModuleSpec | type = None + k_norm: ModuleSpec | type = None + weights_proj: ModuleSpec | type = None + + +class DSAMultiLatentAttention(Attention): + """DSA-enabled Multi-Latent Attention base class. + + Holds the shared init (output proj, rotary embedding, softmax scale) and + the forward path that composes indexer + SparseMLA kernel. Self-attention + specialization (q/kv down/up projections, indexer submodules) lives on + `DSAMLASelfAttention`. + """ + + def __init__( + self, + config: MLATransformerConfig, + submodules: DSASelfAttentionSubmodules, + layer_number: int, + attn_mask_type: AttnMaskType, + attention_type: str, + is_mtp_layer: bool = False, + cp_comm_type: str | None = None, + model_comm_pgs=None, + pg_collection=None, + ) -> None: + super().__init__( + config=config, + submodules=submodules, + layer_number=layer_number, + attention_type=attention_type, + attn_mask_type=attn_mask_type, + cp_comm_type=cp_comm_type, + pg_collection=pg_collection, + ) + self.query_projection_size = ( + self.config.v_head_dim * self.config.num_attention_heads + ) + self.q_head_dim = self.config.qk_head_dim + self.config.qk_pos_emb_head_dim + + # Overwrite base class kv shape for MLA inference compatibility. + self.key_hidden_size = self.q_head_dim + self.val_hidden_size = self.config.v_head_dim + + self.recompute_up_proj = ( + self.config.recompute_granularity == "selective" + and "mla_up_proj" in self.config.recompute_modules + ) + self.qkv_up_checkpoint = None + + mscale = _yarn_get_mscale(self.config.rotary_scaling_factor, self.config.mscale) + self.softmax_scale = mscale * mscale / math.sqrt(self.q_head_dim) + + if self.config.rope_type == "rope": + self.rotary_pos_emb = RotaryEmbedding( + self.config.qk_pos_emb_head_dim, + rotary_percent=self.config.rotary_percent, + rotary_base=self.config.rotary_base, + cp_group=self.pg_collection.cp, + ) + elif self.config.rope_type == "yarn": + self.rotary_pos_emb = YarnRotaryEmbedding( + self.config.qk_pos_emb_head_dim, + rotary_base=self.config.rotary_base, + scaling_factor=self.config.rotary_scaling_factor, + original_max_position_embeddings=self.config.original_max_position_embeddings, + beta_fast=self.config.beta_fast, + beta_slow=self.config.beta_slow, + mscale=self.config.mscale, + mscale_all_dim=self.config.mscale_all_dim, + cp_group=self.pg_collection.cp, + ) + else: + raise ValueError( + f"Unsupported RoPE type: {self.config.rope_type}; " + "supported types are 'rope' and 'yarn'" + ) + + # Output projection. + self.linear_proj = build_module( + submodules.linear_proj, + self.query_projection_size, + self.config.hidden_size, + config=self.config, + init_method=self.config.output_layer_init_method, + bias=self.config.add_bias_linear, + input_is_parallel=True, + skip_bias_add=True, + is_expert=False, + tp_comm_buffer_name="proj", + tp_group=self.pg_collection.tp, + ) + + self.index_topk = getattr(self.config, "dsa_indexer_topk", None) or 2048 + + def forward( + self, + hidden_states, + attention_mask, + key_value_states=None, + inference_context=None, + rotary_pos_emb=None, + rotary_pos_cos=None, + rotary_pos_sin=None, + rotary_pos_cos_sin=None, + attention_bias=None, + packed_seq_params=None, + position_ids=None, + sequence_len_offset=None, + *, + inference_params=None, + router_token_masks=None, + loss_mask=None, + ): + """Forward pass for DSA multi-latent attention.""" + assert rotary_pos_emb is None, "Rotary pos emb should not be passed into MLA." + assert attention_bias is None, "Attention bias should not be passed into MLA." + assert rotary_pos_cos is None and rotary_pos_sin is None, ( + "MLA does not support Flash Decoding" + ) + + q, kv, wv, index_query, index_key, head_weights = ( + self.get_absorb_query_key_value_tensors( + hidden_states, + key_value_states, + position_ids, + packed_seq_params, + inference_context=inference_context, + ) + ) + + def fused_select_topk(index_q, index_k, w, starts, ends, block_size=8192): + seq_len = index_q.shape[0] + # Clip topk to available key length. TileLang indexer_bwd kernel + # requires topk to be a power of 2 (assert in + # tilelang_indexer_bwd.py:27). key_len may not be power-of-2 + # (e.g. 9728 under CP=2), so clip to largest 2^n that still fits. + raw_cap = min(self.index_topk, int(index_k.shape[0])) + effective_topk = 1 << max(0, (raw_cap).bit_length() - 1) + if effective_topk > raw_cap: + effective_topk >>= 1 + indexer_topk_scores = [] + topk_indices = [] + for start in range(0, seq_len, block_size): + end = min(start + block_size, seq_len) + index_q_block = index_q[start:end] + w_block = w[start:end] + starts_block = starts[start:end] + ends_block = ends[start:end] + scores_block, indices_block = lighting_indexer( + index_q_block, + index_k, + w_block, + starts_block.to(torch.int32), + ends_block.to(torch.int32), + effective_topk, + topk_indices=None, + ) + scores_block = torch.softmax(scores_block, dim=-1) + indexer_topk_scores.append(scores_block) + topk_indices.append(indices_block) + return ( + torch.cat(indexer_topk_scores, dim=0), + torch.cat(topk_indices, dim=0).unsqueeze(1), + ) + + index_key = index_key.squeeze(1) + head_weights = head_weights.unsqueeze(-1) + + cp_size = parallel_state.get_context_parallel_world_size() + + # R18 diagnostic: AREAL_DSA_FORCE_CP1=1 bypasses CP zigzag logic, + # using simple causal starts/ends even under CP>1. If NLL drops to + # baseline, the bug is in the zigzag remap above. + import os as _dsa_os + + _force_cp1 = _dsa_os.environ.get("AREAL_DSA_FORCE_CP1", "") == "1" + + if cp_size > 1 and not _force_cp1: + # index_key is cp-gathered in zigzag-interleaved layout: + # [rank0_front; rank0_back; rank1_front; rank1_back; ...] + # = [chunk0; chunk_{2C-1}; chunk1; chunk_{2C-2}; ...] + # The indexer's clean_logits kernel enforces causal mask via a + # contiguous [starts, ends) range. In zigzag layout the causally + # valid KV set is non-contiguous, so we unzigzag index_key to + # sequential layout for the indexer. Safe because index_key is + # detached — no gradient flows through it. + from areal.models.mcore.lightning_attention import ( + _build_zigzag_undo_indices, + ) + + total_len = index_key.shape[0] + undo_idx = _build_zigzag_undo_indices( + total_len, cp_size, packed_seq_params.cu_seqlens_q, index_key.device + ) + index_key_seq = index_key[undo_idx] + + # Map each local query to its real sequential position. + cp_rank = parallel_state.get_context_parallel_rank() + local_len = total_len // cp_size + gathered_pos = torch.arange( + cp_rank * local_len, + (cp_rank + 1) * local_len, + device=index_key.device, + ) + redo_idx = torch.empty(total_len, dtype=torch.long, device=index_key.device) + redo_idx[undo_idx] = torch.arange(total_len, device=index_key.device) + real_pos = redo_idx[gathered_pos] + + cu_sq = packed_seq_params.cu_seqlens_q + seq_ids = torch.searchsorted(cu_sq, real_pos, right=True) - 1 + starts = cu_sq[seq_ids] + ends = real_pos + 1 + + indexer_topk_scores, topk_indices = fused_select_topk( + index_query, index_key_seq, head_weights, starts, ends + ) + + # Remap topk_indices from sequential back to zigzag space for + # SparseMLA (which operates on zigzag-layout kv). + ti = topk_indices.squeeze(1) + valid = ti != -1 + remapped = undo_idx[ti.long().clamp(min=0)] + remapped[~valid] = -1 + topk_indices = remapped.to(topk_indices.dtype).unsqueeze(1) + else: + starts, ends = generate_varlen_mask_params(packed_seq_params.cu_seqlens_q) + indexer_topk_scores, topk_indices = fused_select_topk( + index_query, index_key, head_weights, starts, ends + ) + + core_attn_out, _ = SparseMLA.apply(q, kv, topk_indices, self.softmax_scale) + core_attn_out = torch.einsum("thm,hdm->thd", core_attn_out, wv) + core_attn_out = core_attn_out.reshape(core_attn_out.size(0), 1, -1) + + if self.recompute_up_proj: + assert self.qkv_up_checkpoint is not None + self.qkv_up_checkpoint.discard_output_and_register_recompute(core_attn_out) + self.qkv_up_checkpoint = None + + output, bias = self.linear_proj(core_attn_out) + return output, bias + + +class DSAMLASelfAttention(DSAMultiLatentAttention): + """DSA Multi-Latent Self-Attention layer. + + Takes input of shape [s, b, h] and returns output of the same shape. Adds + the indexer submodules (wq_b, wk, k_norm, weights_proj) on top of the + standard MLA projections. + """ + + def __init__( + self, + config: MLATransformerConfig, + submodules: DSASelfAttentionSubmodules, + layer_number: int, + attn_mask_type=AttnMaskType.padding, + is_mtp_layer: bool = False, + cp_comm_type: str | None = None, + model_comm_pgs=None, + pg_collection=None, + ): + super().__init__( + config=config, + submodules=submodules, + layer_number=layer_number, + attn_mask_type=attn_mask_type, + attention_type="self", + is_mtp_layer=is_mtp_layer, + cp_comm_type=cp_comm_type, + model_comm_pgs=model_comm_pgs, + pg_collection=pg_collection, + ) + + q_down_proj_kwargs: dict = {} + if submodules.linear_q_down_proj is TELinear: + q_down_proj_kwargs["parallel_mode"] = "duplicated" + elif submodules.linear_q_down_proj in ( + TEColumnParallelLinear, + ColumnParallelLinear, + ): + q_down_proj_kwargs["gather_output"] = False + else: + raise ValueError( + f"Unsupported linear_q_down_proj: {submodules.linear_q_down_proj}" + ) + + self.linear_q_down_proj = build_module( + submodules.linear_q_down_proj, + self.config.hidden_size, + self.config.q_lora_rank, + config=self.config, + init_method=self.config.init_method, + bias=False, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name="q_down_proj", + skip_weight_param_allocation=False, + **q_down_proj_kwargs, + ) + + self.linear_q_up_proj = build_module( + submodules.linear_q_up_proj, + self.config.q_lora_rank, + self.config.num_attention_heads * self.q_head_dim, + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=False, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name="q_up_proj", + ) + + kv_down_proj_kwargs: dict = {} + if submodules.linear_kv_down_proj is TELinear: + kv_down_proj_kwargs["parallel_mode"] = "duplicated" + elif submodules.linear_kv_down_proj in ( + TEColumnParallelLinear, + ColumnParallelLinear, + ): + kv_down_proj_kwargs["gather_output"] = False + else: + raise ValueError( + f"Unsupported linear_kv_down_proj: {submodules.linear_kv_down_proj}" + ) + + self.linear_kv_down_proj = build_module( + submodules.linear_kv_down_proj, + self.config.hidden_size, + self.config.kv_lora_rank + self.config.qk_pos_emb_head_dim, + config=self.config, + init_method=self.config.init_method, + bias=False, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name="kv_down_proj", + skip_weight_param_allocation=False, + **kv_down_proj_kwargs, + ) + + self.linear_kv_up_proj = build_module( + submodules.linear_kv_up_proj, + self.config.kv_lora_rank, + self.config.num_attention_heads + * (self.config.qk_head_dim + self.config.v_head_dim), + config=self.config, + init_method=self.config.init_method, + gather_output=False, + bias=False, + skip_bias_add=False, + is_expert=False, + tp_comm_buffer_name="kv_up_proj", + ) + + self.q_layernorm = build_module( + submodules.q_layernorm, + hidden_size=self.config.q_lora_rank, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + + self.kv_layernorm = build_module( + submodules.kv_layernorm, + hidden_size=self.config.kv_lora_rank, + config=self.config, + eps=self.config.layernorm_epsilon, + ) + + # Indexer submodules. + indexer_linear_kwargs = dict( + config=self.config, + init_method=self.config.init_method, + bias=False, + skip_bias_add=False, + is_expert=False, + parallel_mode="duplicated", + skip_weight_param_allocation=False, + ) + + self.wq_b = build_module( + submodules.wq_b, + input_size=self.config.q_lora_rank, + output_size=self.config.dsa_indexer_n_heads + * self.config.dsa_indexer_head_dim, + tp_comm_buffer_name="wq_b", + **indexer_linear_kwargs, + ) + self.wq_b.weight._skip_gather = True + + self.wk = build_module( + submodules.wk, + input_size=self.config.hidden_size, + output_size=self.config.dsa_indexer_head_dim, + tp_comm_buffer_name="wk", + **indexer_linear_kwargs, + ) + + # k_norm uses LayerNorm (not RMSNorm) per DSA design. Toggle config temporarily. + old_norm = self.config.normalization + assert config.normalization == "RMSNorm" + self.config.normalization = "LayerNorm" + self.k_norm = build_module( + submodules.k_norm, + hidden_size=self.config.dsa_indexer_head_dim, + config=self.config, + eps=1e-6, # hardcoded per DSA reference implementation + ) + self.config.normalization = old_norm + + self.weights_proj = build_module( + submodules.weights_proj, + input_size=self.config.hidden_size, + output_size=self.config.dsa_indexer_n_heads, + tp_comm_buffer_name="weights_proj", + **indexer_linear_kwargs, + ) + self.weights_proj.weight._skip_gather = True + + # 2026-04-30: freeze 4 个 indexer 模块,对齐 slime 默认行为 + # (slime/utils/arguments.py:197 `--freeze-params-name-list + # self_attention.wq_b self_attention.wk self_attention.k_norm + # self_attention.weights_proj`)。 + # 原因:slime 用 `q_compressed.detach()` 切断 indexer 与主反传链, + # indexer params 自然不参与 grad,DDP 跳过 grad ready check;我们 + # 之前用 `core_attn_out + scores.sum() * 0` reattach,目的同是绕 + # DDP assert,但下游 grad NaN×0=NaN 反而污染 indexer params。 + # 走 slime 的 freeze 方案,既对齐 reference,又消除 NaN 通路。 + # 可以用 AREAL_DSA_TRAIN_INDEXER=1 关闭 freeze(以备未来 RL 阶段 + # 解锁 indexer 训练)。 + import os as _os_freeze + + _train_indexer = _os_freeze.environ.get("AREAL_DSA_TRAIN_INDEXER", "0") == "1" + if not _train_indexer: + for _mod in (self.wq_b, self.wk, self.k_norm, self.weights_proj): + for _p in _mod.parameters(): + _p.requires_grad = False + + def get_absorb_query_key_value_tensors( + self, + hidden_states, + key_value_states=None, + position_ids=None, + packed_seq_params=None, + inference_context=None, + *, + inference_params=None, + ): + """Derive `query`, `key` and `value` tensors from `hidden_states`.""" + assert hidden_states.ndim == 3, ( + f"hidden_states should be 3D [s, b, n*h], got {hidden_states.ndim}D" + ) + assert packed_seq_params is not None + + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + inference_context, None, hidden_states, self.config, packed_seq_params + ) + rotary_pos_emb = self.rotary_pos_emb( + rotary_seq_len, packed_seq=packed_seq_params is not None + ) + # RotaryEmbedding returns a Tensor, YarnRotaryEmbedding returns + # (emb, mscale). softmax_scale is precomputed in __init__ via + # _yarn_get_mscale, so mscale at runtime is unused either way. + if isinstance(rotary_pos_emb, tuple): + rotary_pos_emb = rotary_pos_emb[0] + + cu_seqlens_q = packed_seq_params.cu_seqlens_q + cu_seqlens_kv = packed_seq_params.cu_seqlens_kv + + # QKV down projection + layer norm. + q_compressed, _ = self.linear_q_down_proj(hidden_states) + q_compressed = q_compressed.squeeze(1) + + kv_combined, _ = self.linear_kv_down_proj(hidden_states) + if self.config.sequence_parallel: + kv_combined = gather_from_sequence_parallel_region(kv_combined) + kv_compressed, k_pos_emb = torch.split( + kv_combined, + [self.config.kv_lora_rank, self.config.qk_pos_emb_head_dim], + dim=-1, + ) + kv_compressed = self.kv_layernorm(kv_compressed) + + # Absorb. + q_compressed = self.q_layernorm(q_compressed) + q, _ = self.linear_q_up_proj(q_compressed) + q = q.view( + *q.size()[:-1], + self.num_attention_heads_per_partition, + self.q_head_dim, + ) + q_no_pe, q_pos_emb = torch.split( + q, + [self.config.qk_head_dim, self.config.qk_pos_emb_head_dim], + dim=-1, + ) + + w_kc, w_vc = self.linear_kv_up_proj.weight.unflatten( + 0, + (-1, self.config.qk_head_dim + self.config.v_head_dim), + ).split([self.config.qk_head_dim, self.config.v_head_dim], dim=1) + + q_no_pe = torch.einsum("thd,hdm->thm", q_no_pe, w_kc) + + # Fuse rms_norm with layer_norm_weight so kv gradient all-reduces in TP. + kv_compressed = torch.nn.functional.rms_norm( + kv_compressed.float(), + normalized_shape=(kv_compressed.shape[-1],), + weight=self.linear_kv_up_proj.layer_norm_weight.float(), + eps=self.config.layernorm_epsilon, + ).to(kv_compressed.dtype) + + cp_group = parallel_state.get_context_parallel_group() + _cp_size = parallel_state.get_context_parallel_world_size() + + def _cp_all_gather(t): + if _cp_size <= 1: + return t + t = t.contiguous() + t_list = [torch.empty_like(t) for _ in range(_cp_size)] + torch.distributed.all_gather(t_list, t, group=cp_group) + return torch.cat(t_list, dim=0) + + k_pos_emb = _cp_all_gather(k_pos_emb) + kv_compressed = _cp_all_gather(kv_compressed) + + def fuse_rope(t_in, cu_seqlens, gathered=False): + # MLA interleaved rope: split into [x0,x2,...] + [x1,x3,...]. + x1 = t_in[..., 0::2] + x2 = t_in[..., 1::2] + t = torch.cat((x1, x2), dim=-1) + _cp_size = parallel_state.get_context_parallel_world_size() + if _cp_size <= 1: + return fused_apply_rotary_pos_emb_thd( + t, cu_seqlens, rotary_pos_emb.squeeze(0) + ) + from areal.models.mcore.lightning_attention import ( + _build_zigzag_redo_indices, + _build_zigzag_undo_indices, + ) + + if not gathered: + # t is the local zigzag slice; all-gather to get full zigzag tensor. + t_list = [torch.empty_like(t) for _ in range(_cp_size)] + torch.distributed.all_gather(t_list, t, group=cp_group) + t = torch.cat(t_list, dim=0) + # t is now the full zigzag tensor; unzigzag → rope → rezigzag. + _total = t.shape[0] + _undo = _build_zigzag_undo_indices(_total, _cp_size, cu_seqlens, t.device) + _redo = _build_zigzag_redo_indices(_undo) + t_seq = t[_undo] + # rotary_pos_emb is sequential (pos 0..total-1), not zigzag-sliced. + _rope_seq = rotary_pos_emb.squeeze(0) + t_seq = fused_apply_rotary_pos_emb_thd(t_seq, cu_seqlens, _rope_seq) + t_zz = t_seq[_redo] + if not gathered: + # Return only this rank's local slice. + _cp_rank = parallel_state.get_context_parallel_rank() + _local = _total // _cp_size + return t_zz[_cp_rank * _local : (_cp_rank + 1) * _local] + return t_zz + + q_pos_emb = fuse_rope(q_pos_emb, cu_seqlens_q, gathered=False) + k_pos_emb = fuse_rope(k_pos_emb, cu_seqlens_kv, gathered=True) + + query = torch.cat([q_no_pe, q_pos_emb], dim=-1).contiguous() + key = torch.cat([kv_compressed, k_pos_emb], dim=-1).contiguous() + + # Indexer. Detach to cut gradient flow from indexer into base projections. + q_compressed = q_compressed.detach() + hidden_states = hidden_states.detach() + rotary_pos_emb = rotary_pos_emb.detach() + + index_q, _ = self.wq_b(q_compressed) + index_q = index_q.view( + *index_q.size()[:-1], + self.config.dsa_indexer_n_heads, + self.config.dsa_indexer_head_dim, + ) + if self.config.sequence_parallel: + index_q = gather_from_sequence_parallel_region(index_q) + + index_k, _ = self.wk(hidden_states) + index_k = self.k_norm(index_k.squeeze(1).float()).bfloat16() + if self.config.sequence_parallel: + index_k = gather_from_sequence_parallel_region(index_k) + index_k = _cp_all_gather(index_k).unsqueeze(1) + + head_weights = WeightLinearFunction.apply( + hidden_states, self.weights_proj.weight, None, torch.float32 + ) + head_weights = head_weights.squeeze(1) * ( + (self.config.dsa_indexer_n_heads**-0.5) + * (self.config.dsa_indexer_head_dim**-0.5) + ) + if self.config.sequence_parallel: + head_weights = gather_from_sequence_parallel_region(head_weights) + + # GLM-5.1 indexer weight layout: first rope_dim dims are RoPE, + # remaining dims are position-independent (matches SGLang/slime). + index_q_pe, index_q_no_pe = torch.split( + index_q, + [ + self.config.qk_pos_emb_head_dim, + self.config.dsa_indexer_head_dim - self.config.qk_pos_emb_head_dim, + ], + dim=-1, + ) + index_q_pe = fuse_rope(index_q_pe, cu_seqlens_q, gathered=False) + index_query = torch.cat([index_q_pe, index_q_no_pe], dim=-1) + + index_k_pe, index_k_no_pe = torch.split( + index_k, + [ + self.config.qk_pos_emb_head_dim, + self.config.dsa_indexer_head_dim - self.config.qk_pos_emb_head_dim, + ], + dim=-1, + ) + index_k_pe = fuse_rope(index_k_pe, cu_seqlens_kv, gathered=True) + index_key = torch.cat([index_k_pe, index_k_no_pe], dim=-1) + + return query, key, w_vc, index_query, index_key, head_weights + + def get_query_key_value_tensors(self) -> NoReturn: + raise NotImplementedError( + "DSAMLASelfAttention uses get_absorb_query_key_value_tensors(); " + "the standard path is not supported." + ) diff --git a/areal/models/mcore/glm5_megatron_bridge.py b/areal/models/mcore/glm5_megatron_bridge.py new file mode 100644 index 0000000000..9bcef8c914 --- /dev/null +++ b/areal/models/mcore/glm5_megatron_bridge.py @@ -0,0 +1,355 @@ +# SPDX-License-Identifier: Apache-2.0 + +"""Megatron-Bridge registration for GLM-5.1 (GlmMoeDsaForCausalLM). + +Registers the GLM-5.1 architecture with NVIDIA's open-source megatron-bridge, +enabling ``bridge_type: megatron-bridge`` in AReaL for this model. + +GLM-5.1 uses Multi-Latent Attention (MLA) like DeepSeek V3, plus DSA +(Dynamic Sparse Attention) indexer weights. +""" + +import os +from collections.abc import Mapping +from functools import partial + +import torch +from megatron.bridge.models.conversion.mapping_registry import MegatronMappingRegistry +from megatron.bridge.models.conversion.model_bridge import ( + MegatronModelBridge, + WeightConversionTask, +) +from megatron.bridge.models.conversion.param_mapping import AutoMapping, GatedMLPMapping +from megatron.bridge.models.deepseek.common import get_common_mapping_list +from megatron.bridge.models.hf_pretrained.causal_lm import PreTrainedCausalLM +from megatron.bridge.models.mla_provider import MLAModelProvider +from megatron.core.models.gpt.gpt_layer_specs import get_gpt_decoder_block_spec +from megatron.core.models.gpt.gpt_model import GPTModel + +from areal.utils import logging + +logger = logging.getLogger("GLM5Bridge") + +try: + import transformer_engine # noqa: F401 + + HAVE_TE = True +except (ImportError, ModuleNotFoundError): + HAVE_TE = False + + +# --------------------------------------------------------------------------- +# GLM-5.1 Bridge +# --------------------------------------------------------------------------- + +# DSA indexer weight definitions: (megatron_name, hf_name) +_DSA_INDEXER_MAPPINGS: list[tuple[str, str]] = [ + ( + "decoder.layers.*.self_attention.wq_b.weight", + "model.layers.*.self_attn.indexer.wq_b.weight", + ), + ( + "decoder.layers.*.self_attention.wk.weight", + "model.layers.*.self_attn.indexer.wk.weight", + ), + ( + "decoder.layers.*.self_attention.weights_proj.weight", + "model.layers.*.self_attn.indexer.weights_proj.weight", + ), + ( + "decoder.layers.*.self_attention.k_norm.weight", + "model.layers.*.self_attn.indexer.k_norm.weight", + ), + ( + "decoder.layers.*.self_attention.k_norm.bias", + "model.layers.*.self_attn.indexer.k_norm.bias", + ), +] + + +def _get_rope_theta(hf_config) -> float: + """Extract rope_theta from HF config, handling GLM-5.1's nested structure.""" + if hasattr(hf_config, "rope_parameters") and isinstance( + hf_config.rope_parameters, dict + ): + return float(hf_config.rope_parameters.get("rope_theta", 10000.0)) + return float(getattr(hf_config, "rope_theta", 10000.0)) + + +@MegatronModelBridge.register_bridge( + source="GlmMoeDsaForCausalLM", + target=GPTModel, + provider=MLAModelProvider, + model_type="glm_moe_dsa", +) +class GLM5Bridge(MegatronModelBridge): + """Megatron Bridge for GLM-5.1 (GlmMoeDsa) with MLA + MoE + DSA.""" + + def provider_bridge(self, hf_pretrained: PreTrainedCausalLM) -> MLAModelProvider: + provider = super().provider_bridge(hf_pretrained) + hf_config = hf_pretrained.config + + # Layer spec + provider.transformer_layer_spec = partial( + get_gpt_decoder_block_spec, use_transformer_engine=HAVE_TE + ) + + # Architecture basics + provider.normalization = "RMSNorm" + provider.gated_linear_unit = True + provider.position_embedding_type = "rope" + provider.add_bias_linear = False + provider.share_embeddings_and_output_weights = False + provider.qk_layernorm = True + provider.multi_latent_attention = True + + # MoE + provider.moe_grouped_gemm = True + provider.moe_router_pre_softmax = True + provider.moe_token_dispatcher_type = "alltoall" + provider.moe_router_load_balancing_type = "none" + provider.moe_shared_expert_overlap = False + provider.moe_router_score_function = "sigmoid" + provider.moe_router_enable_expert_bias = True + provider.moe_router_bias_update_rate = 0.0 + provider.moe_router_dtype = "fp32" + provider.moe_permute_fusion = True + provider.moe_z_loss_coeff = 3.5e-06 + + # Fusions + provider.apply_rope_fusion = False + provider.bias_activation_fusion = True + provider.bias_dropout_fusion = True + provider.cross_entropy_loss_fusion = False + provider.masked_softmax_fusion = True + provider.persist_layer_norm = True + provider.gradient_accumulation_fusion = True + + # Misc + provider.hidden_dropout = 0.0 + provider.attention_softmax_in_fp32 = True + provider.disable_bf16_reduced_precision_matmul = True + provider.make_vocab_size_divisible_by = 128 + provider.seq_length = getattr(hf_config, "max_position_embeddings", 4096) + + # Rope — GLM-5.1 stores rope_theta in rope_parameters dict + provider.rotary_base = _get_rope_theta(hf_config) + provider.rotary_scaling_factor = 1.0 + provider.rope_type = "rope" + + # MoE layer frequency + provider.moe_layer_freq = [0] * hf_config.first_k_dense_replace + [1] * ( + hf_config.num_hidden_layers - hf_config.first_k_dense_replace + ) + provider.moe_shared_expert_intermediate_size = ( + hf_config.moe_intermediate_size * hf_config.n_shared_experts + ) + + # MTP + mtp_num_layers = getattr(hf_config, "num_nextn_predict_layers", None) + if os.environ.get("AREAL_DISABLE_MTP", "0") == "1" and mtp_num_layers: + logger.warning( + f"AREAL_DISABLE_MTP=1: overriding mtp_num_layers from {mtp_num_layers} to 0" + ) + mtp_num_layers = 0 + provider.mtp_num_layers = mtp_num_layers + + # DSA (Dynamic Sparse Attention) — set on provider so the internal + # TransformerConfig picks them up when creating DSAMLASelfAttention. + if ( + getattr(hf_config, "index_topk", None) is not None + and getattr(hf_config, "index_n_heads", None) is not None + ): + provider.dsa_indexer_n_heads = hf_config.index_n_heads + provider.dsa_indexer_head_dim = hf_config.index_head_dim + provider.dsa_indexer_topk = hf_config.index_topk + provider.dsa_indexer_loss_coeff = getattr( + hf_config, "dsa_indexer_loss_coeff", 0.0 + ) + provider.dsa_indexer_use_sparse_loss = getattr( + hf_config, "dsa_indexer_use_sparse_loss", False + ) + + return provider + + def mapping_registry(self) -> MegatronMappingRegistry: + mapping_list = get_common_mapping_list() + + # Expert bias + mapping_list.append( + AutoMapping( + megatron_param="decoder.layers.*.mlp.router.expert_bias", + hf_param="model.layers.*.mlp.gate.e_score_correction_bias", + ) + ) + + # DSA indexer weights + for mcore_name, hf_name in _DSA_INDEXER_MAPPINGS: + mapping_list.append( + AutoMapping( + megatron_param=mcore_name, + hf_param=hf_name, + ) + ) + + # MTP layer mappings (if present) + mapping_list.extend(self._get_mtp_mappings()) + + return MegatronMappingRegistry(*mapping_list) + + def maybe_modify_converted_hf_weight( + self, + task: WeightConversionTask, + converted_weights_dict: dict[str, torch.Tensor], + hf_state_dict: Mapping[str, torch.Tensor], + ) -> dict[str, torch.Tensor]: + """Add rotary inv_freq to HF state dict if the original checkpoint had it.""" + global_name = task.global_param_name + if not global_name.startswith("decoder.layers.") or not global_name.endswith( + ".input_layernorm.weight" + ): + return converted_weights_dict + + parts = global_name.split(".") + if len(parts) < 4 or not parts[2].isdigit(): + return converted_weights_dict + + inv_freq_prefix = "model.layers." + inv_freq_suffix = ".self_attn.rotary_emb.inv_freq" + layer_idx = int(parts[2]) + inv_freq_key = f"{inv_freq_prefix}{layer_idx}{inv_freq_suffix}" + if inv_freq_key in converted_weights_dict: + return converted_weights_dict + + has_inv_freq = getattr(self, "_glm5_has_inv_freq", None) + if has_inv_freq is None: + has_inv_freq = any( + key.startswith(inv_freq_prefix) and key.endswith(inv_freq_suffix) + for key in hf_state_dict.keys() + ) + self._glm5_has_inv_freq = has_inv_freq + if not has_inv_freq: + return converted_weights_dict + + inv_freq = getattr(self, "_glm5_inv_freq", None) + if inv_freq is None: + rotary_dim = self.hf_config.qk_rope_head_dim + rotary_base = _get_rope_theta(self.hf_config) + inv_freq = 1.0 / ( + rotary_base + ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim) + ) + self._glm5_inv_freq = inv_freq + + if converted_weights_dict: + ref = next(iter(converted_weights_dict.values())) + if inv_freq.device != ref.device: + inv_freq = inv_freq.to(device=ref.device) + self._glm5_inv_freq = inv_freq + + converted_weights_dict[inv_freq_key] = inv_freq + return converted_weights_dict + + # --------------------------------------------------------------- + # MTP layer mappings + # --------------------------------------------------------------- + + def _get_mtp_mappings(self) -> list: + hf_config = getattr(self, "hf_config", None) + if hf_config is None: + return [] + num_mtp = getattr(hf_config, "num_nextn_predict_layers", 0) + if not num_mtp or num_mtp <= 0: + return [] + + num_layers = hf_config.num_hidden_layers + mappings: list = [] + + _MTP_LAYER_MAPPINGS = { + "mtp.layers.*.transformer_layer.input_layernorm.weight": "model.layers.*.input_layernorm.weight", + "mtp.layers.*.transformer_layer.self_attention.linear_proj.weight": "model.layers.*.self_attn.o_proj.weight", + "mtp.layers.*.transformer_layer.pre_mlp_layernorm.weight": "model.layers.*.post_attention_layernorm.weight", + "mtp.layers.*.transformer_layer.mlp.linear_fc1.layer_norm_weight": "model.layers.*.post_attention_layernorm.weight", + "mtp.layers.*.transformer_layer.self_attention.linear_kv_down_proj.weight": "model.layers.*.self_attn.kv_a_proj_with_mqa.weight", + "mtp.layers.*.transformer_layer.self_attention.linear_kv_up_proj.weight": "model.layers.*.self_attn.kv_b_proj.weight", + "mtp.layers.*.transformer_layer.self_attention.linear_kv_up_proj.layer_norm_weight": "model.layers.*.self_attn.kv_a_layernorm.weight", + "mtp.layers.*.transformer_layer.kv_layernorm.weight": "model.layers.*.self_attn.kv_a_layernorm.weight", + "mtp.layers.*.transformer_layer.mlp.linear_fc2.weight": "model.layers.*.mlp.down_proj.weight", + "mtp.layers.*.transformer_layer.mlp.router.weight": "model.layers.*.mlp.gate.weight", + "mtp.layers.*.transformer_layer.mlp.router.expert_bias": "model.layers.*.mlp.gate.e_score_correction_bias", + "mtp.layers.*.transformer_layer.mlp.experts.linear_fc2.weight*": "model.layers.*.mlp.experts.*.down_proj.weight", + "mtp.layers.*.transformer_layer.mlp.shared_experts.linear_fc2.weight": "model.layers.*.mlp.shared_experts.down_proj.weight", + "mtp.layers.*.transformer_layer.self_attention.linear_q_down_proj.weight": "model.layers.*.self_attn.q_a_proj.weight", + "mtp.layers.*.transformer_layer.self_attention.linear_q_up_proj.weight": "model.layers.*.self_attn.q_b_proj.weight", + "mtp.layers.*.transformer_layer.self_attention.linear_q_up_proj.layer_norm_weight": "model.layers.*.self_attn.q_a_layernorm.weight", + "mtp.layers.*.transformer_layer.q_layernorm.weight": "model.layers.*.self_attn.q_a_layernorm.weight", + } + + for mtp_idx in range(num_mtp): + layer_idx = mtp_idx + num_layers + + # MTP-specific weights + mappings.extend( + [ + AutoMapping( + megatron_param=f"mtp.layers.{mtp_idx}.enorm.weight", + hf_param=f"model.layers.{layer_idx}.enorm.weight", + ), + AutoMapping( + megatron_param=f"mtp.layers.{mtp_idx}.hnorm.weight", + hf_param=f"model.layers.{layer_idx}.hnorm.weight", + ), + AutoMapping( + megatron_param=f"mtp.layers.{mtp_idx}.eh_proj.weight", + hf_param=f"model.layers.{layer_idx}.eh_proj.weight", + ), + AutoMapping( + megatron_param=f"mtp.layers.{mtp_idx}.final_layernorm.weight", + hf_param=f"model.layers.{layer_idx}.shared_head.norm.weight", + ), + ] + ) + + # Standard layer mappings adapted for MTP + for mcore_pat, hf_pat in _MTP_LAYER_MAPPINGS.items(): + mappings.append( + AutoMapping( + megatron_param=mcore_pat.replace("*", str(mtp_idx), 1), + hf_param=hf_pat.replace("*", str(layer_idx), 1), + ) + ) + + # GatedMLP for MTP + mappings.extend( + [ + GatedMLPMapping( + megatron_param=f"mtp.layers.{mtp_idx}.transformer_layer.mlp.linear_fc1.weight", + gate=f"model.layers.{layer_idx}.mlp.gate_proj.weight", + up=f"model.layers.{layer_idx}.mlp.up_proj.weight", + ), + GatedMLPMapping( + megatron_param=f"mtp.layers.{mtp_idx}.transformer_layer.mlp.experts.linear_fc1.weight*", + gate=f"model.layers.{layer_idx}.mlp.experts.*.gate_proj.weight", + up=f"model.layers.{layer_idx}.mlp.experts.*.up_proj.weight", + ), + GatedMLPMapping( + megatron_param=f"mtp.layers.{mtp_idx}.transformer_layer.mlp.shared_experts.linear_fc1.weight", + gate=f"model.layers.{layer_idx}.mlp.shared_experts.gate_proj.weight", + up=f"model.layers.{layer_idx}.mlp.shared_experts.up_proj.weight", + ), + ] + ) + + # DSA indexer weights for MTP layers + for mcore_pat, hf_pat in _DSA_INDEXER_MAPPINGS: + mappings.append( + AutoMapping( + megatron_param=mcore_pat.replace( + "decoder.layers.*", + f"mtp.layers.{mtp_idx}.transformer_layer", + ), + hf_param=hf_pat.replace("layers.*", f"layers.{layer_idx}"), + ) + ) + + return mappings diff --git a/areal/models/mcore/registry.py b/areal/models/mcore/registry.py index 3ed2da8544..73bf9ba8a7 100644 --- a/areal/models/mcore/registry.py +++ b/areal/models/mcore/registry.py @@ -17,6 +17,11 @@ hf_to_mcore_config_bailing_moe, make_mcore_layer_specs_bailing_moe, ) +from areal.models.mcore.deepseek_v3 import ( + _has_dsa, + hf_to_mcore_config_deepseek_v3, + make_mcore_layer_specs_deepseek_v3, +) from areal.models.mcore.qwen3 import ( hf_to_mcore_config_qwen3_dense, make_mcore_layer_specs_qwen3_dense, @@ -119,6 +124,12 @@ def unwrap_to_gpt_model(model: torch.nn.Module) -> GPTModel: "BailingHybridForCausalLM", } +_DEEPSEEK_V3_ARCHITECTURES = { + "DeepseekV3ForCausalLM", + "GlmMoeDsaForCausalLM", + "Glm4MoeForCausalLM", +} + def _is_bailing(hf_config: PretrainedConfig) -> bool: """Return True if hf_config belongs to the BailingMoeV2.5 family.""" @@ -126,6 +137,30 @@ def _is_bailing(hf_config: PretrainedConfig) -> bool: return bool(architectures) and architectures[0] in _BAILING_ARCHITECTURES +def _supplement_dsa_config(hf_config: PretrainedConfig, tf_config) -> None: + """Backfill DSA-specific fields onto tf_config when applicable. + + Megatron-Bridge's TransformerConfig does not natively include DSA + (Dynamic Sparse Attention) parameters. When the model is DSA-enabled, + copy the indexer settings from the HF config so DSAMLASelfAttention + can be constructed correctly. + """ + if not _has_dsa(hf_config): + return + dsa_attrs = { + "dsa_indexer_n_heads": hf_config.index_n_heads, + "dsa_indexer_head_dim": hf_config.index_head_dim, + "dsa_indexer_topk": hf_config.index_topk, + "dsa_indexer_loss_coeff": getattr(hf_config, "dsa_indexer_loss_coeff", 0.0), + "dsa_indexer_use_sparse_loss": getattr( + hf_config, "dsa_indexer_use_sparse_loss", False + ), + } + for attr, val in dsa_attrs.items(): + if getattr(tf_config, attr, None) is None: + setattr(tf_config, attr, val) + + # Model registry for different architectures def make_hf_and_mcore_config( hf_path: str, @@ -142,6 +177,7 @@ def make_hf_and_mcore_config( if hasattr(hf_config, "_name_or_path"): hf_config._name_or_path = hf_path tf_config = bridge.transformer_config + _supplement_dsa_config(hf_config, tf_config) return hf_config, tf_config else: hf_config: PretrainedConfig = AutoConfig.from_pretrained( @@ -154,6 +190,8 @@ def make_hf_and_mcore_config( return hf_config, hf_to_mcore_config_qwen3_dense(hf_config, dtype) elif architecture in _BAILING_ARCHITECTURES: return hf_config, hf_to_mcore_config_bailing_moe(hf_config, dtype) + elif architecture in _DEEPSEEK_V3_ARCHITECTURES: + return hf_config, hf_to_mcore_config_deepseek_v3(hf_config, dtype) else: raise ValueError( f"Architecture not registered for config conversion: {architecture}." @@ -167,6 +205,8 @@ def make_mcore_layer_specs(hf_config: PretrainedConfig, tf_config: TransformerCo return make_mcore_layer_specs_qwen3_dense(tf_config, use_te=True) elif architecture in _BAILING_ARCHITECTURES: return make_mcore_layer_specs_bailing_moe(tf_config, hf_config, use_te=True) + elif architecture in _DEEPSEEK_V3_ARCHITECTURES: + return make_mcore_layer_specs_deepseek_v3(tf_config, hf_config, use_te=True) else: raise ValueError( f"Architecture not registered for config conversion: {architecture}." @@ -212,8 +252,14 @@ def make_mcore_model( # one that megatron-bridge does not provide. The lambda matches # megatron-bridge's call signature # ``(config, vp_stage=None) -> TransformerBlockSubmodules``. + # * DSA models (GLM-5.1): need DSAMLASelfAttention with indexer. # * Bailing-MoE V2.5: per-layer heterogeneous Lightning + MLA. - if _is_bailing(hf_config): + if _has_dsa(hf_config): + _dsa_specs = make_mcore_layer_specs(hf_config, tf_config) + provider.transformer_layer_spec = ( + lambda config, vp_stage=None, _s=_dsa_specs: _s + ) + elif _is_bailing(hf_config): _bailing_specs = make_mcore_layer_specs(hf_config, tf_config) provider.transformer_layer_spec = ( lambda config, vp_stage=None, _s=_bailing_specs: _s diff --git a/docs/en/best_practices/migrate_to_megatron_bridge.md b/docs/en/best_practices/migrate_to_megatron_bridge.md index b49bd86205..43a9500963 100644 --- a/docs/en/best_practices/migrate_to_megatron_bridge.md +++ b/docs/en/best_practices/migrate_to_megatron_bridge.md @@ -20,12 +20,13 @@ new model architecture under the `megatron-bridge` backend. ## When to use which -| Need | Prefer | -| -------------------------------------------------- | ----------------- | -| Existing setups, disk-based HF weight load/save | `mbridge` | -| Tree-attention training in `MegatronEngine` | `mbridge` | -| PEFT/LoRA support | `megatron-bridge` | -| Architectures NVIDIA upstream maintains officially | `megatron-bridge` | +| Need | Prefer | +| ------------------------------------------------------ | ----------------- | +| Existing setups, disk-based HF weight load/save | `mbridge` | +| Tree-attention training in `MegatronEngine` | `mbridge` | +| PEFT/LoRA support | `megatron-bridge` | +| Newer model architectures (GLM-5.1 DSA, GLM-4.7-Flash) | `megatron-bridge` | +| Architectures NVIDIA upstream maintains officially | `megatron-bridge` | `megatron-bridge` is the long-term direction: it has PEFT support, NVIDIA upstream maintenance, and broader model coverage. `mbridge` is being deprecated but remains the @@ -57,17 +58,21 @@ new model adapters. Architectures registered in AReaL's `mcore/registry.py` and routed through both backends: -| HF architecture | `mbridge` | `megatron-bridge` | Notes | -| ----------------------------- | --------- | ----------------- | ------------------------------------- | -| `Qwen3ForCausalLM` | ✅ | ✅ (NV upstream) | Dense. | -| `BailingMoeV2_5ForCausalLM` | ✅ | ✅ | Heterogeneous Lightning + MLA layers. | -| `BailingMoeLinearForCausalLM` | ✅ | ✅ | Shares `BailingMoeV25Bridge` adapter. | -| `BailingHybridForCausalLM` | ✅ | ✅ | Shares `BailingMoeV25Bridge` adapter. | +| HF architecture | `mbridge` | `megatron-bridge` | Notes | +| -------------------------------- | --------- | ----------------- | --------------------------------------------------- | +| `Qwen3ForCausalLM` | ✅ | ✅ (NV upstream) | Dense. | +| `BailingMoeV2_5ForCausalLM` | ✅ | ✅ | Heterogeneous Lightning + MLA layers. | +| `BailingMoeLinearForCausalLM` | ✅ | ✅ | Shares `BailingMoeV25Bridge` adapter. | +| `BailingHybridForCausalLM` | ✅ | ✅ | Shares `BailingMoeV25Bridge` adapter. | +| `DeepseekV3ForCausalLM` | ✅ | ✅ (NV upstream) | Homogeneous MLA + MoE. | +| `GlmMoeDsaForCausalLM` (GLM-5.1) | ✅ | ✅ | MLA + MoE + DSA (Dynamic Sparse Attention) indexer. | +| `Glm4MoeForCausalLM` | ✅ | ✅ (NV upstream) | GLM-4.7-Flash class. | Custom adapters live under `areal/models/mcore/`: -- mbridge subclasses: `bailing_moe_bridge.py` -- megatron-bridge subclasses: `bailing_moe_megatron_bridge.py` +- mbridge subclasses: `bailing_moe_bridge.py`, `deepseek_v3_bridge.py` +- megatron-bridge subclasses: `bailing_moe_megatron_bridge.py`, + `glm5_megatron_bridge.py` ## How registry dispatch works @@ -77,17 +82,24 @@ Custom adapters live under `areal/models/mcore/`: - With `bridge_type="mbridge"`, returns `(bridge.hf_config, bridge.config)`. - With `bridge_type="megatron-bridge"`, returns - `(bridge.hf_pretrained.config, bridge.transformer_config)`. + `(bridge.hf_pretrained.config, bridge.transformer_config)`. Also backfills + DSA-specific fields onto the transformer config via `_supplement_dsa_config()` when + the model is DSA-enabled. - `make_mcore_model(hf_config, tf_config, mcore_config, bridge, bridge_type, ...)`: - With `mbridge`, calls `bridge.get_model(...)`. + - With `megatron-bridge`, calls `bridge.to_megatron_provider(load_weights=False)` and configures the provider with the current TP/PP/CP/EP context, then `provider.provide_distributed_model(...)`. Before configuring the provider, it - overrides `provider.transformer_layer_spec` for models whose layer structure - megatron-bridge's default spec doesn't express — currently the **Bailing-MoE V2.5 - family**, which uses AReaL's heterogeneous Lightning + MLA layer spec. + overrides `provider.transformer_layer_spec` for two cases NVIDIA's default spec + doesn't cover: + + - **DSA models** (e.g., GLM-5.1): uses AReaL's `DSAMLASelfAttention` so the indexer + modules (`wq_b`, `wk`, `k_norm`, `weights_proj`) are wired correctly. + - **Bailing-MoE V2.5 family**: uses AReaL's heterogeneous Lightning + MLA layer + spec. ## Adding a new model under `megatron-bridge` @@ -137,7 +149,8 @@ import areal.models.mcore.my_model_megatron_bridge # noqa: F401 # register bri If the model has custom attention modules that megatron-bridge's default `get_gpt_decoder_block_spec` cannot express, add a branch in -`registry.make_mcore_model()` to inject your spec (mirror the `_is_bailing` branch). +`registry.make_mcore_model()` to inject your spec (mirror the `_has_dsa` / `_is_bailing` +branches). ## Common pitfalls From 9aad59f9e1b8cee738c1c50852bca15f502bd337 Mon Sep 17 00:00:00 2001 From: "chucai.dzq" Date: Thu, 28 May 2026 12:24:13 +0800 Subject: [PATCH 2/2] fix(dsa): align IndexerFunction.backward grad count with forward inputs forward(ctx, ...) takes 7 inputs after ctx (index_q, index_k, weights, cu_seqlen_ks, cu_seqlen_ke, topk, topk_indices), so PyTorch's autograd contract requires backward to return exactly 7 values. The previous return tuple had 10 entries (3 grads + 7 Nones); fix to 7 (3 grads + 4 Nones) so non-frozen indexer training does not crash with "function returned an incorrect number of gradients". This bug was masked in validated experiments because the default AREAL_DSA_TRAIN_INDEXER=0 freezes all 4 indexer parameter modules (wq_b, wk, k_norm, weights_proj) via requires_grad=False, and dsa_mla_attention.py additionally detaches q_compressed and hidden_states before the indexer call. With all upstream tensors and parameters requires_grad=False, autograd skips the IndexerFunction backward entirely and the buggy return tuple is never inspected. Setting AREAL_DSA_TRAIN_INDEXER=1 (per the code comment, "for future RL stages") would have triggered the crash. This fix changes no numerical behavior on the validated default-freeze path; it just unblocks the indexer-training path. --- areal/experimental/ops/dsa/indexer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/areal/experimental/ops/dsa/indexer.py b/areal/experimental/ops/dsa/indexer.py index d0067972e3..afc27776dd 100644 --- a/areal/experimental/ops/dsa/indexer.py +++ b/areal/experimental/ops/dsa/indexer.py @@ -74,7 +74,9 @@ def backward(ctx, grad_scores, grad_indices): grad_q, grad_w, grad_k = indexer_bwd_interface( index_q, weights, index_k, topk_indices, grad_scores ) - return grad_q, grad_k, grad_w, None, None, None, None, None, None, None + # 7 returns matching forward inputs (excluding ctx): + # index_q, index_k, weights, cu_seqlen_ks, cu_seqlen_ke, topk, topk_indices + return grad_q, grad_k, grad_w, None, None, None, None def lighting_indexer(