Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions areal/engine/megatron_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,13 @@
import areal.models.mcore.bailing_moe_bridge # noqa: F401 # register bridge
except ImportError:
pass
try:
import areal.models.mcore.deepseek_v3_bridge # noqa: F401 # register bridge
except ImportError:
pass
# megatron-bridge adapters do not depend on mbridge and register on import.
import areal.models.mcore.bailing_moe_megatron_bridge # noqa: F401 # register bridge
import areal.models.mcore.glm5_megatron_bridge # noqa: F401 # register bridge
from areal.api import (
FinetuneSpec,
InferenceEngine,
Expand Down
1 change: 1 addition & 0 deletions areal/experimental/ops/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: Apache-2.0
1 change: 1 addition & 0 deletions areal/experimental/ops/dsa/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# SPDX-License-Identifier: Apache-2.0
104 changes: 104 additions & 0 deletions areal/experimental/ops/dsa/indexer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# SPDX-License-Identifier: Apache-2.0

import os

import torch

from .tilelang_indexer_bwd import indexer_bwd_interface
from .tilelang_indexer_fwd import indexer_fwd_interface

# Optional recording of the DSA indexer's top-k selections, used to compare
# routing decisions between runs.
#
# Recording is switched on when the AREAL_DUMP_ROUTING environment variable is
# set to any non-empty string (note: this includes "0"). The variable is read
# once, at import time. While enabled, every indexer forward pass appends its
# topk_indices (moved to CPU) to the list below; the consumer is expected to
# dump and then clear it (done in megatron_engine.py per the original note).
_recorded_dsa_indices: list[torch.Tensor] = []
_record_dsa = os.environ.get("AREAL_DUMP_ROUTING", "") != ""


def get_recorded_dsa_indices() -> list[torch.Tensor]:
    """Return the module-level list of recorded top-k index tensors.

    The list object itself is returned (not a copy), so callers observe later
    appends and any mutation they perform is visible module-wide.
    """
    recorded = _recorded_dsa_indices
    return recorded


def clear_recorded_dsa_indices():
    """Empty the module-level recording list in place.

    The list object is kept (only its contents are removed), so references
    previously obtained from ``get_recorded_dsa_indices`` remain valid.
    """
    del _recorded_dsa_indices[:]


def pytorch_extract_topk_scores(logits, topk_indices, dim=-1):
    """Gather the logits selected by ``topk_indices`` along ``dim``.

    Entries of ``topk_indices`` equal to ``-1`` are treated as padding: the
    corresponding output positions are set to ``-inf`` instead of a gathered
    value.

    Args:
        logits: Tensor of scores to gather from.
        topk_indices: Integer tensor of positions along ``dim``; ``-1`` marks
            an unused slot.
        dim: Dimension of ``logits`` to gather along.

    Returns:
        Tensor shaped like ``topk_indices`` holding the selected scores, with
        ``-inf`` in padded slots.
    """
    is_real = topk_indices != -1
    # Padding slots are clamped to index 0 so the gather stays in range; their
    # gathered values are discarded by the ``where`` below.
    gather_index = topk_indices.clamp(min=0).to(torch.int64)
    gathered = logits.gather(dim, gather_index)
    return torch.where(is_real, gathered, float("-inf"))


class IndexerFunction(torch.autograd.Function):
    """Autograd function wrapping the TileLang indexer kernels.

    Forward computes indexer logits through ``indexer_fwd_interface``, selects
    (or reuses) the top-k key positions per query and returns their scores.
    Backward delegates to ``indexer_bwd_interface``.
    """

    @staticmethod
    def forward(
        ctx,
        index_q: torch.Tensor,
        index_k: torch.Tensor,
        weights: torch.Tensor,
        cu_seqlen_ks: torch.Tensor,
        cu_seqlen_ke: torch.Tensor,
        topk: int,
        topk_indices: torch.Tensor | None = None,
    ):
        """Compute top-k indexer scores.

        Args:
            index_q: Indexer queries; must be 3-D, the middle dimension is
                taken as the head count.
            index_k: Indexer keys.
            weights: Weights forwarded to the forward/backward kernels.
            cu_seqlen_ks: Forwarded to the forward kernel (presumably the
                per-query start of the visible key range -- confirm against
                the kernel).
            cu_seqlen_ke: Forwarded to the forward kernel (presumably the
                per-query end of the visible key range -- confirm against the
                kernel).
            topk: Number of positions to keep when ``topk_indices`` is not
                supplied.
            topk_indices: Optional precomputed selection; when given it is
                used as-is and ``topk`` only ends up stored on ``ctx``.

        Returns:
            Tuple ``(index_score, topk_indices)``; slots marked ``-1`` in
            ``topk_indices`` carry a score of ``-inf``.
        """
        # Unpacking enforces a 3-D index_q; only the head count is retained.
        _, num_heads, _ = index_q.shape

        logits = indexer_fwd_interface(
            index_q,
            index_k,
            weights,
            cu_seqlen_ks,
            cu_seqlen_ke,
            clean_logits=True,
        )

        if topk_indices is None:
            # Stable descending sort so that ties resolve deterministically.
            ranking = torch.argsort(-logits, dim=-1, stable=True)
            selected = ranking[..., :topk].to(torch.int32)
            selected_logits = torch.gather(
                logits, dim=-1, index=selected.to(torch.int64)
            )
            # Positions whose logit is -inf are not real picks: flag them -1.
            topk_indices = selected.masked_fill(selected_logits == -torch.inf, -1)

        index_score = pytorch_extract_topk_scores(logits, topk_indices)

        if _record_dsa:
            # Keep a CPU copy for the optional routing dump.
            _recorded_dsa_indices.append(topk_indices.detach().cpu())

        ctx.topk = topk
        ctx.head_num = num_heads
        ctx.save_for_backward(
            index_q, index_k, weights, cu_seqlen_ks, cu_seqlen_ke, topk_indices
        )
        return index_score, topk_indices

    @staticmethod
    def backward(ctx, grad_scores, grad_indices):
        """Backpropagate through the selected scores.

        ``grad_indices`` is ignored. Gradients are produced for ``index_q``,
        ``index_k`` and ``weights`` only; the remaining four forward inputs
        receive ``None``.
        """
        saved = ctx.saved_tensors
        index_q, index_k, weights, _seqlen_ks, _seqlen_ke, topk_indices = saved

        # The kernel takes (q, weights, k) and returns (dq, dw, dk); both
        # orderings differ from forward's (q, k, weights) input order.
        grad_q, grad_w, grad_k = indexer_bwd_interface(
            index_q, weights, index_k, topk_indices, grad_scores
        )

        # One entry per forward input after ctx:
        # index_q, index_k, weights, cu_seqlen_ks, cu_seqlen_ke, topk, topk_indices
        return grad_q, grad_k, grad_w, None, None, None, None


def lighting_indexer(
    index_q: torch.Tensor,
    index_k: torch.Tensor,
    weights: torch.Tensor,
    cu_seqlen_ks: torch.Tensor,
    cu_seqlen_ke: torch.Tensor,
    topk: int,
    topk_indices: torch.Tensor | None = None,
):
    """Functional entry point for ``IndexerFunction``.

    Drops a trailing size-1 dimension from ``weights`` (if present) and
    forwards every argument, in order, to ``IndexerFunction.apply``.

    Returns:
        The ``(index_score, topk_indices)`` pair produced by
        ``IndexerFunction``.
    """
    squeezed_weights = weights.squeeze(-1)
    return IndexerFunction.apply(
        index_q,
        index_k,
        squeezed_weights,
        cu_seqlen_ks,
        cu_seqlen_ke,
        topk,
        topk_indices,
    )


def generate_varlen_mask_params(cu_seqlens):
    """Build per-token ``[start, end)`` ranges for packed sequences.

    For every token position ``i`` in ``[0, cu_seqlens[-1])`` the result gives
    the offset at which the sequence containing ``i`` begins and ``i + 1`` as
    the exclusive end, i.e. a range covering the token's own sequence up to
    and including itself.

    Args:
        cu_seqlens: 1-D tensor of cumulative sequence lengths; the last entry
            is the total token count.

    Returns:
        Tuple ``(starts, ends)`` with one entry per token.

    Note:
        Reading ``cu_seqlens[-1]`` with ``.item()`` copies a value to the
        host. The final sanity check is an ``assert`` and is therefore
        skipped when Python runs with ``-O``.
    """
    total_tokens = cu_seqlens[-1].item()
    token_pos = torch.arange(0, total_tokens, device=cu_seqlens.device)
    # right=True counts boundaries <= position, so subtracting one yields the
    # index of the sequence that owns each token.
    owner = torch.searchsorted(cu_seqlens, token_pos, right=True) - 1
    starts = cu_seqlens[owner]
    ends = token_pos + 1
    assert torch.all(ends > starts)
    return starts, ends
163 changes: 163 additions & 0 deletions areal/experimental/ops/dsa/sparse_mla.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
# SPDX-License-Identifier: Apache-2.0

import os

import torch

from .tilelang_sparse_mla_bwd import sparse_mla_bwd
from .tilelang_sparse_mla_fwd import sparse_mla_fwd_interface


def _pytorch_sparse_mla_bwd(
    q, kv, tl_out, grad_output, indices, tl_lse, scaling, chunk_size=32
):
    """Chunked pure-PyTorch backward pass for sparse MLA attention.

    Keys are the full ``kv`` rows and values are their first ``D_v`` features,
    so both the key and the value gradients are accumulated into one ``dkv``
    tensor.

    Memory: gathering the selected keys for every query at once would
    materialize an fp32 tensor of shape ``(S, G, TOPK, D_full)``. To bound
    peak memory the queries are processed ``chunk_size`` rows at a time.

    Math: ``tl_lse`` is expected to be the base-2 log-sum-exp produced by the
    forward kernel, ``log2(sum_i exp(score_i * scaling))``, hence
    ``prob_i = exp(score_i * scaling - tl_lse * ln(2))``.

    Args:
        q: Queries, shape ``(S, H, D_full)``.
        kv: Shared key/value tensor, shape ``(S_kv, G, D_full)``.
        tl_out: Forward output, shape ``(S, H, D_v)``.
        grad_output: Gradient w.r.t. ``tl_out``, shape ``(S, H, D_v)``.
        indices: Selected kv rows per query and group, shape
            ``(S, G, TOPK)``; ``-1`` marks an unused slot.
        tl_lse: Base-2 log-sum-exp from the forward pass, viewable as
            ``(S, H)``.
        scaling: Softmax scale applied to the raw scores.
        chunk_size: Number of query rows processed per iteration. Affects
            peak memory only, not the result.

    Returns:
        Tuple ``(dq, dkv)``: ``dq`` has the shape and dtype of ``q``; ``dkv``
        has the shape of ``kv`` and is returned in float32.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    S, H, D_full = q.shape
    S_kv, G, _ = kv.shape
    D_v = tl_out.shape[-1]
    TOPK = indices.shape[-1]
    H_per_group = H // G
    ln2 = 0.6931471805599453

    # All arithmetic is done in fp32 regardless of the input dtype.
    q_f = q.float()
    kv_f = kv.float()
    do_f = grad_output.float()
    o_f = tl_out.float()

    # Delta = sum_d o * do, the softmax-backward correction term: (S, H)
    delta_full = (o_f * do_f).sum(dim=-1)

    # Gradient accumulators.
    dq = torch.zeros(S, H, D_full, device=q.device, dtype=torch.float32)
    dkv_fp32 = torch.zeros(S_kv, G, D_full, device=q.device, dtype=torch.float32)

    for s0 in range(0, S, chunk_size):
        s1 = min(s0 + chunk_size, S)
        cs = s1 - s0  # rows in this chunk (the last chunk may be short)
        idx_c = indices[s0:s1]  # (cs, G, TOPK)
        # Padding slots are clamped to row 0 so gathers/scatters stay in
        # range; their contributions are zeroed through the masks below.
        safe_c = idx_c.clamp(min=0).long()  # (cs, G, TOPK)
        valid_c = idx_c != -1  # (cs, G, TOPK)
        invalid_chg = ~valid_c.unsqueeze(2)  # (cs, G, 1, TOPK), broadcast over heads

        # Gather the selected keys per group: (cs, G, TOPK, D_full)
        k_c = torch.stack(
            [
                kv_f[:, g, :]
                .index_select(0, safe_c[:, g, :].reshape(-1))
                .view(cs, TOPK, D_full)
                for g in range(G)
            ],
            dim=1,
        )
        # Split heads by kv group so the gathered keys are shared by all heads
        # of a group instead of being replicated per head.
        q_c = q_f[s0:s1].view(cs, G, H_per_group, D_full)  # (cs, G, Hg, D_full)
        do_c = do_f[s0:s1].view(cs, G, H_per_group, D_v)  # (cs, G, Hg, D_v)
        v_c = k_c[..., :D_v]  # (cs, G, TOPK, D_v)
        lse_c = tl_lse[s0:s1].view(cs, G, H_per_group)  # (cs, G, Hg)
        delta_c = delta_full[s0:s1].view(cs, G, H_per_group)  # (cs, G, Hg)

        # score[c, g, h, t] = <q, k> * scaling, with padding pushed to -1e30.
        score = torch.einsum("cghd,cgtd->cght", q_c, k_c) * scaling
        score.masked_fill_(invalid_chg, -1e30)
        # prob = exp(score - lse * ln(2)); padding is forced to exactly zero.
        prob = torch.exp(score - lse_c.unsqueeze(-1) * ln2)
        prob.masked_fill_(invalid_chg, 0.0)
        del score

        # dp_raw = do @ V^T, then the softmax backward (scaled score gradient).
        dp_raw = torch.einsum("cghd,cgtd->cght", do_c, v_c)
        dp = prob * (dp_raw - delta_c.unsqueeze(-1)) * scaling
        dp.masked_fill_(invalid_chg, 0.0)
        del dp_raw

        # dq[c, g, h, d] = sum_t dp * K. reshape (not view) because einsum
        # does not guarantee a contiguous result.
        dq[s0:s1] = torch.einsum("cght,cgtd->cghd", dp, k_c).reshape(cs, H, D_full)

        # Scatter-add the key and value gradients back to the kv rows.
        for g in range(G):
            idx_g = safe_c[:, g, :].reshape(-1)  # (cs*TOPK,)
            invalid_gm = ~valid_c[:, g, :].unsqueeze(-1)  # (cs, TOPK, 1)
            # Key part: dp^T @ q -> (cs, TOPK, D_full)
            k_contrib = torch.einsum("cht,chd->ctd", dp[:, g, :, :], q_c[:, g, :, :])
            # Value part: prob^T @ do -> (cs, TOPK, D_v)
            v_contrib = torch.einsum(
                "cht,chd->ctd", prob[:, g, :, :], do_c[:, g, :, :]
            )
            # Padding slots alias row 0 after clamping; make sure nothing
            # (including NaN/Inf from 0 * non-finite) is added there.
            k_contrib = k_contrib.masked_fill(invalid_gm, 0.0)
            v_contrib = v_contrib.masked_fill(invalid_gm, 0.0)
            dkv_fp32[:, g, :].index_add_(0, idx_g, k_contrib.reshape(-1, D_full))
            dkv_fp32[:, g, :D_v].index_add_(0, idx_g, v_contrib.reshape(-1, D_v))
            del k_contrib, v_contrib
        del prob, dp, k_c, v_c

    return dq.to(q.dtype).contiguous(), dkv_fp32.contiguous()


class SparseMLA(torch.autograd.Function):
    """Autograd function for sparse MLA attention backed by TileLang kernels."""

    @staticmethod
    def forward(ctx, q, kv, indices, scaling):
        """Run the sparse MLA forward kernel.

        Args:
            q: Query tensor (seq_len, heads, dim_plus_tail_dim)
            kv: Key-Value tensor (seq_len_kv, kv_group, dim_plus_tail_dim)
            indices: Sparse indices tensor (seq_len, kv_group, topk)
            scaling: Softmax scale, passed to the kernel as ``sm_scale``.

        Returns:
            Tuple ``(out, lse)``: ``out`` is the output tensor
            (seq_len, heads, dim) and ``lse`` is the second value returned by
            the forward kernel, kept for the backward pass.
        """
        # Every tensor is made contiguous before it reaches the kernel.
        q = q.contiguous()
        kv = kv.contiguous()
        indices = indices.contiguous()

        out, lse = sparse_mla_fwd_interface(q, kv, indices, sm_scale=scaling)

        # Stash everything the backward pass needs.
        ctx.scaling = scaling
        ctx.save_for_backward(q, kv, indices, out, lse)
        return out, lse

    @staticmethod
    def backward(ctx, grad_output, grad_lse):
        """Compute gradients for ``q`` and ``kv``.

        Args:
            grad_output: Gradient of the loss with respect to the output.
            grad_lse: Gradient with respect to the returned ``lse``; ignored.

        Returns:
            ``(dq, dkv, None, None)`` -- ``indices`` and ``scaling`` receive
            no gradient.
        """
        q, kv, indices, out, lse = ctx.saved_tensors

        # AREAL_SPARSE_MLA_PYTORCH_BWD=1 selects the pure-PyTorch backward.
        # It exists as a workaround for NaNs coming out of the TileLang
        # backward kernel while that issue is being investigated. The variable
        # is re-read on every backward call.
        use_torch_bwd = os.environ.get("AREAL_SPARSE_MLA_PYTORCH_BWD", "0") == "1"

        if not use_torch_bwd:
            dq, dkv = sparse_mla_bwd(
                q,
                kv,
                out,
                grad_output.contiguous(),
                indices,
                lse,
                sm_scale=ctx.scaling,
            )
        else:
            dq, dkv = _pytorch_sparse_mla_bwd(
                q, kv, out, grad_output, indices, lse, ctx.scaling
            )

        # One entry per forward input after ctx: q, kv, indices, scaling.
        return dq, dkv, None, None
Loading
Loading