Conversation
|
|
||
|
|
||
from causal_conv1d import causal_conv1d_fn as _cuda_causal_conv1d_fn
|
|
||
|
|
||
| def causal_conv1d_fn( | ||
| x: torch.Tensor, | ||
| weight: torch.Tensor, | ||
| bias: torch.Tensor | None = None, | ||
| activation: str | None = None, | ||
| seq_idx: torch.Tensor | None = None, | ||
| **kwargs, | ||
| ) -> tuple[torch.Tensor, None]: | ||
| # FLA convention is [B, T, D]; causal_conv1d expects [B, D, T] | ||
| out = _cuda_causal_conv1d_fn( | ||
| x.transpose(1, 2), weight, bias=bias, seq_idx=seq_idx, activation=activation | ||
| ) | ||
| return out.transpose(1, 2), None |
There was a problem hiding this comment.
[when launching a multinode training run using submit_sft.py]:
With FLA's causal_conv1d layer I was seeing a few issues:
- CUDA illegal memory access in `fla/modules/conv/triton/ops.py` during the backward pass
- `torch.utils.checkpoint.CheckpointError: torch.utils.checkpoint: Recomputed values for the following tensors have different metadata than during the forward pass.`
I replaced the layer with the one from `causal_conv1d`, which fixes both issues. Let me know if there is a preferred alternative way around it.
| nodes_array=($nodes) | ||
| head_node=${nodes_array[0]} | ||
| export head_node_ip=$(srun --nodes=1 --ntasks=1 -w "$head_node" hostname --ip-address) | ||
| export head_node_ip=$(getent hosts "$head_node" | awk '{print $1}') |
There was a problem hiding this comment.
The `srun` here was deadlocking; this change resolves the head node's IP locally via `getent` instead.
| @@ -239,11 +239,7 @@ def _to_hf_qwen3next(self, state_dict: dict[str, Any]) -> dict[str, Any]: | |||
| ) | |||
| hf_state_dict.update(local_expert_fqn) | |||
| else: | |||
| n_experts = ( | |||
| self.model_args.moe_args.num_experts | |||
| if "shared" not in key | |||
| else self.model_args.moe_args.num_shared_experts | |||
| ) | |||
| n_experts = self.model_args.moe_args.num_experts | |||
| split_values = self._split_experts_weights(value, n_experts) | |||
| for expert_num in range(n_experts): | |||
| new_key = new_abstract_key.format(layer_num, expert_num) | |||
| @@ -268,15 +264,9 @@ def _from_hf_qwen3next(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: | |||
| expert_weights_by_layer: dict[str, dict[str, dict]] = {} | |||
|
|
|||
| for key, value in hf_state_dict.items(): | |||
| if "mlp.experts." in key or "mlp.shared_expert." in key: | |||
| abstract_key = re.sub( | |||
| r"(\d+)", "{}", key, count=2 if "experts" in key else 1 | |||
| ) | |||
| if "experts" in key: | |||
| layer_num, expert_num = re.findall(r"\d+", key) | |||
| else: | |||
| layer_num = re.search(r"\d+", key).group(0) | |||
| expert_num = None | |||
| if "mlp.experts." in key and "mlp.shared_expert." not in key: | |||
| abstract_key = re.sub(r"(\d+)", "{}", key, count=2) | |||
| layer_num, expert_num = re.findall(r"\d+", key) | |||
| titan_abstract_key = self.from_hf_map[abstract_key] | |||
| new_key = titan_abstract_key.format(layer_num) | |||
|
|
|||
| @@ -285,7 +275,7 @@ def _from_hf_qwen3next(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: | |||
| if titan_abstract_key not in expert_weights_by_layer[layer_num]: | |||
| expert_weights_by_layer[layer_num][titan_abstract_key] = {} | |||
| expert_weights_by_layer[layer_num][titan_abstract_key][ | |||
| expert_num | |||
| int(expert_num) | |||
| ] = value | |||
|
|
|||
| if isinstance(value, DTensor): | |||
| @@ -296,16 +286,11 @@ def _from_hf_qwen3next(self, hf_state_dict: dict[str, Any]) -> dict[str, Any]: | |||
| value.device_mesh, | |||
| ) | |||
| else: | |||
| n_experts = ( | |||
| self.model_args.moe_args.num_experts | |||
| if "experts" in key | |||
| else self.model_args.moe_args.num_shared_experts | |||
| ) | |||
| stacked_value = self._concatenate_expert_weights( | |||
| expert_weights_by_layer, | |||
| titan_abstract_key, | |||
| layer_num, | |||
| n_experts, | |||
| self.model_args.moe_args.num_experts, | |||
There was a problem hiding this comment.
The shared experts were going through the grouped-expert path (which expects 3D tensors), causing `IndexError: too many indices for tensor of dimension 2`. This change takes them out of that path.
There was a problem hiding this comment.
We should move these into a cluster-specific setup, since your settings conflict with the B200 cluster — perhaps some hostname-based `.sh` routing?
No description provided.