Checklist
Describe the bug
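Training crashes in the draft model's `fc` projection. It appears that hidden states were obtained from only one layer of the target model: the input has 2048 features, while `fc` expects 10240 input features (5 × 2048, presumably the concatenated hidden states of five target layers).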
[rank5]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (5900x2048 and 10240x2048)
[rank2]: Traceback (most recent call last):
[rank2]: File "/dc-hl/zibing.wei/code/dflash/qwen3.5/SpecForge/scripts/train_dflash.py", line 592, in <module>
[rank2]: main()
[rank2]: File "/dc-hl/zibing.wei/code/dflash/qwen3.5/SpecForge/scripts/train_dflash.py", line 534, in main
[rank2]: loss, accuracy = dflash_model(
[rank2]: ^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 851, in forward
[rank2]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/dc-hl/zibing.wei/code/dflash/qwen3.5/SpecForge/specforge/core/dflash.py", line 218, in forward
[rank2]: output_hidden = self.draft_model(
[rank2]: ^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/dc-hl/zibing.wei/code/dflash/qwen3.5/SpecForge/specforge/modeling/draft/dflash.py", line 253, in forward
[rank2]: target_hidden = self.hidden_norm(self.fc(target_hidden))
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 134, in forward
[rank2]: return F.linear(input, self.weight, self.bias)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (6219x2048 and 10240x2048)
Reproduction
Run `scripts/train_dflash.py` with qwen3.5-35b-a3b as the target model.
Environment
qwen3.5-35b-a3b
Checklist
Describe the bug
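Training crashes in the draft model's `fc` projection. It appears that hidden states were obtained from only one layer of the target model: the input has 2048 features, while `fc` expects 10240 input features (5 × 2048, presumably the concatenated hidden states of five target layers).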
[rank5]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (5900x2048 and 10240x2048)
[rank2]: Traceback (most recent call last):
[rank2]: File "/dc-hl/zibing.wei/code/dflash/qwen3.5/SpecForge/scripts/train_dflash.py", line 592, in <module>
[rank2]: main()
[rank2]: File "/dc-hl/zibing.wei/code/dflash/qwen3.5/SpecForge/scripts/train_dflash.py", line 534, in main
[rank2]: loss, accuracy = dflash_model(
[rank2]: ^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 851, in forward
[rank2]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/dc-hl/zibing.wei/code/dflash/qwen3.5/SpecForge/specforge/core/dflash.py", line 218, in forward
[rank2]: output_hidden = self.draft_model(
[rank2]: ^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/dc-hl/zibing.wei/code/dflash/qwen3.5/SpecForge/specforge/modeling/draft/dflash.py", line 253, in forward
[rank2]: target_hidden = self.hidden_norm(self.fc(target_hidden))
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank2]: return self._call_impl(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank2]: return forward_call(*args, **kwargs)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 134, in forward
[rank2]: return F.linear(input, self.weight, self.bias)
[rank2]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank2]: RuntimeError: mat1 and mat2 shapes cannot be multiplied (6219x2048 and 10240x2048)
Reproduction
Run `scripts/train_dflash.py` with qwen3.5-35b-a3b as the target model.
Environment
qwen3.5-35b-a3b