
80b changes #61

Draft

rob-maron wants to merge 2 commits into dev-updated-again from rm/80b-2

Conversation

@rob-maron

No description provided.

Comment on lines +47 to +64


import torch

from causal_conv1d import causal_conv1d_fn as _cuda_causal_conv1d_fn


def causal_conv1d_fn(
x: torch.Tensor,
weight: torch.Tensor,
bias: torch.Tensor | None = None,
activation: str | None = None,
seq_idx: torch.Tensor | None = None,
**kwargs,
) -> tuple[torch.Tensor, None]:
# FLA convention is [B, T, D]; causal_conv1d expects [B, D, T]
out = _cuda_causal_conv1d_fn(
x.transpose(1, 2), weight, bias=bias, seq_idx=seq_idx, activation=activation
)
return out.transpose(1, 2), None
Author

[When launching a multi-node training run using submit_sft.py]:

With FLA's causal_conv1d layer I was seeing a few issues:

  1. CUDA illegal memory access in fla/modules/conv/triton/ops.py during the backward pass
  2. torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.

I replaced the layer with one from the causal_conv1d package, which fixes both issues. Let me know if there is a preferred alternative.

nodes_array=($nodes)
head_node=${nodes_array[0]}
export head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address)
export head_node_ip=$(getent hosts "$head_node" | awk '{print $1}')
Author

The srun here was deadlocking; this change resolves the hostname locally with getent instead of launching a nested job step.
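A minimal sketch of the replacement approach, using "localhost" as a stand-in node name: getent hosts queries NSS (DNS or /etc/hosts) directly on the submitting node, so unlike a nested srun it cannot block waiting on the Slurm controller for a job step allocation.

```shell
# Hypothetical node name for illustration; in the real script this is
# the first entry of $nodes_array.
head_node="localhost"
# Resolve the node's IP locally via NSS instead of `srun ... hostname -I`.
head_node_ip=$(getent hosts "$head_node" | awk '{print $1}')
echo "resolved ${head_node} -> ${head_node_ip}"
```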

Comment on lines 224 to +293
@@ -239,11 +239,7 @@ def _to_hf_qwen3next(self, state_dict: dict[str, Any]) -> dict[str, Any]:
)
hf_state_dict.update(local_expert_fqn)
else:
n_experts = (
self.model_args.moe_args.num_experts
if "shared" not in key
else self.model_args.moe_args.num_shared_experts
)
n_experts = self.model_args.moe_args.num_experts
split_values = self._split_experts_weights(value, n_experts)
for expert_num in range(n_experts):
new_key = new_abstract_key.format(layer_num, expert_num)
@@ -268,15 +264,9 @@ def _from_hf_qwen3next(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
expert_weights_by_layer: dict[str, dict[str, dict]] = {}

for key, value in hf_state_dict.items():
if "mlp.experts." in key or "mlp.shared_expert." in key:
abstract_key = re.sub(
r"(\d+)", "{}", key, count=2 if "experts" in key else 1
)
if "experts" in key:
layer_num, expert_num = re.findall(r"\d+", key)
else:
layer_num = re.search(r"\d+", key).group(0)
expert_num = None
if "mlp.experts." in key and "mlp.shared_expert." not in key:
abstract_key = re.sub(r"(\d+)", "{}", key, count=2)
layer_num, expert_num = re.findall(r"\d+", key)
titan_abstract_key = self.from_hf_map[abstract_key]
new_key = titan_abstract_key.format(layer_num)

@@ -285,7 +275,7 @@ def _from_hf_qwen3next(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
if titan_abstract_key not in expert_weights_by_layer[layer_num]:
expert_weights_by_layer[layer_num][titan_abstract_key] = {}
expert_weights_by_layer[layer_num][titan_abstract_key][
expert_num
int(expert_num)
] = value

if isinstance(value, DTensor):
@@ -296,16 +286,11 @@ def _from_hf_qwen3next(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]:
value.device_mesh,
)
else:
n_experts = (
self.model_args.moe_args.num_experts
if "experts" in key
else self.model_args.moe_args.num_shared_experts
)
stacked_value = self._concatenate_expert_weights(
expert_weights_by_layer,
titan_abstract_key,
layer_num,
n_experts,
self.model_args.moe_args.num_experts,
Author

The shared experts were going through the grouped-expert path (which expects 3D tensors), causing IndexError: too many indices for tensor of dimension 2. This change takes them out of that path.
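An illustrative sketch of the failure mode (not the repo's code): the grouped path stacks each routed expert's 2D [out, in] weight into a 3D [n_experts, out, in] tensor and then indexes per expert, while a shared expert's weight stays 2D, so grouped-style indexing asks for one dimension too many.

```python
import torch

# Routed experts: stacked into 3D, indexed as grouped[expert_num, :, :].
n_experts, d_out, d_in = 4, 8, 16
grouped = torch.stack([torch.randn(d_out, d_in) for _ in range(n_experts)])
assert grouped.shape == (n_experts, d_out, d_in)

# Shared expert: a single 2D weight; the grouped indexing pattern fails.
shared = torch.randn(d_out, d_in)
try:
    _ = shared[0, :, :]  # three indices into a 2D tensor
except IndexError as e:
    print(e)  # "too many indices for tensor of dimension 2"
```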

@dmahan93 dmahan93 requested a review from jquesnelle March 18, 2026 02:15

We should move these to a cluster-specific setup, since your settings conflict with the B200 cluster. Maybe some hostname-based .sh routing?
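One possible shape for that routing, sketched with made-up file names and a hypothetical "b200" hostname pattern; the real cluster names and setup files would need to come from the repo:

```shell
# Pick a per-cluster setup file based on the host name, with a default
# fallback when no pattern matches.
cluster=default
case "$(hostname)" in
  *b200*) cluster=b200 ;;
esac
# In a real launcher this would be: source "setup_${cluster}.sh"
echo "would source setup_${cluster}.sh"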
