Checklist
Describe the bug
When training with dflash, I encountered this error:
[rank3]: Traceback (most recent call last):
[rank3]: File "/dc-hl/zibing.wei/code/dflash/SpecForge/scripts/train_dflash.py", line 548, in
[rank3]: main()
[rank3]: File "/dc-hl/zibing.wei/code/dflash/SpecForge/scripts/train_dflash.py", line 493, in main
[rank3]: loss, accuracy = dflash_model(
[rank3]: ^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]: return self._call_impl(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]: return forward_call(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 851, in forward
[rank3]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]: return self._call_impl(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]: return forward_call(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/dc-hl/zibing.wei/code/dflash/SpecForge/specforge/core/dflash.py", line 208, in forward
[rank3]: dflash_attn_mask = create_dflash_block_mask(
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/dc-hl/zibing.wei/code/dflash/SpecForge/specforge/core/dflash.py", line 61, in create_dflash_block_mask
[rank3]: return create_block_mask(
[rank3]: ^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/attention/flex_attention.py", line 1091, in create_block_mask
[rank3]: mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/attention/flex_attention.py", line 1021, in create_mask
[rank3]: mask = mask_mod(b, h, m, n)
[rank3]: ^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank3]: return vmap_impl(
[rank3]: ^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank3]: return _flat_vmap(
[rank3]: ^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank3]: batched_outputs = func(*batched_inputs, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank3]: return vmap_impl(
[rank3]: ^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank3]: return _flat_vmap(
[rank3]: ^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank3]: batched_outputs = func(*batched_inputs, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank3]: return vmap_impl(
[rank3]: ^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank3]: return _flat_vmap(
[rank3]: ^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank3]: batched_outputs = func(*batched_inputs, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank3]: return vmap_impl(
[rank3]: ^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank3]: return _flat_vmap(
[rank3]: ^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank3]: batched_outputs = func(*batched_inputs, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/dc-hl/zibing.wei/code/dflash/SpecForge/specforge/core/dflash.py", line 42, in dflash_mask_mod
[rank3]: q_block_id = q_idx // block_size
[rank3]: ~~~~~~^^~~~~~~~~~~~
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 44, in wrapped
[rank3]: return handle_torch_function(wrapped, sargs, *sargs, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/overrides.py", line 1728, in handle_torch_function
[rank3]: result = mode.torch_function(public_api, types, args, kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 146, in torch_function
[rank3]: return func(*args, **(kwargs or {}))
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 45, in wrapped
[rank3]: return f(self, *args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 1119, in floordiv
[rank3]: return torch.floor_divide(self, other)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: torch.AcceleratorError: CUDA error: device-side assert triggered
[rank3]: Search for `cudaErrorAssert` in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
[rank3]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank3]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
[rank3]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Reproduction
#!/bin/bash
source /workspace/SpecForge/.venv/bin/activate
uv pip uninstall specforge
cd /dc-hl/zibing.wei/code/dflash/SpecForge
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)
export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=64
NUM_GPUS=8
export PYTHONWARNINGS="ignore"
TRAIN_DATA_PATH=
TARGET_MODEL_NAME_OR_PATH=/Qwen/Qwen3-8B
TRAIN_DATA_PATH=
MODEL_MAX_LENGTH=20000
WANDB_API_KEY=""
PROJECT_NAME=dflash_online
TP_SIZE=2
BATCH_SIZE=2
MODEL_NAME=qwen3_8b
TRAIN_ONLY_LAST_TURN=False
EXPERIMENT_NAME=dflash-${MODEL_NAME}-tp_size-${TP_SIZE}-batch_szie-${BATCH_SIZE}-TRAIN_ONLY_LAST_TURN${TRAIN_ONLY_LAST_TURN}-test
OUTPUT_DIR=/dc-hl/zibing.wei/output/onestepmodel_draftmodel/$EXPERIMENT_NAME
ATTENTION_BACKEND=${2:-flex_attention}
torchrun
--standalone
--nproc_per_node $NUM_GPUS
$ROOT_DIR/scripts/train_dflash.py
--target-model-path $TARGET_MODEL_NAME_OR_PATH
--draft-config-path $ROOT_DIR/configs/qwen3-8b-dflash.json
--train-data-path $TRAIN_DATA_PATH
--output-dir $OUTPUT_DIR
--num-epochs 6
--batch-size $BATCH_SIZE
--tp-size $TP_SIZE
--learning-rate 6e-4
--warmup-ratio 0.04
--max-grad-norm 1.0
--max-length $MODEL_MAX_LENGTH
--chat-template qwen
--attention-backend $ATTENTION_BACKEND
--loss-decay-gamma 7.0
--log-interval 50
--save-interval 1000
--report-to wandb
--wandb-project $PROJECT_NAME
--wandb-name $EXPERIMENT_NAME
--wandb-key $WANDB_API_KEY
--wandb-project specforge-qwen3-8b-dflash
--target-model-backend sglang
--sglang-mem-fraction-static 0.2
--train-only-last-turn $TRAIN_ONLY_LAST_TURN
--block-size 4
--wandb-name qwen3-8b-dflash-perfectblend
Environment
environment installed using requirements
Checklist
Describe the bug
When training with dflash, I encountered this error:
[rank3]: Traceback (most recent call last):
[rank3]: File "/dc-hl/zibing.wei/code/dflash/SpecForge/scripts/train_dflash.py", line 548, in
[rank3]: main()
[rank3]: File "/dc-hl/zibing.wei/code/dflash/SpecForge/scripts/train_dflash.py", line 493, in main
[rank3]: loss, accuracy = dflash_model(
[rank3]: ^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]: return self._call_impl(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]: return forward_call(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 851, in forward
[rank3]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]: return self._call_impl(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]: return forward_call(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/dc-hl/zibing.wei/code/dflash/SpecForge/specforge/core/dflash.py", line 208, in forward
[rank3]: dflash_attn_mask = create_dflash_block_mask(
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/dc-hl/zibing.wei/code/dflash/SpecForge/specforge/core/dflash.py", line 61, in create_dflash_block_mask
[rank3]: return create_block_mask(
[rank3]: ^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/attention/flex_attention.py", line 1091, in create_block_mask
[rank3]: mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/attention/flex_attention.py", line 1021, in create_mask
[rank3]: mask = mask_mod(b, h, m, n)
[rank3]: ^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank3]: return vmap_impl(
[rank3]: ^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank3]: return _flat_vmap(
[rank3]: ^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank3]: batched_outputs = func(*batched_inputs, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank3]: return vmap_impl(
[rank3]: ^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank3]: return _flat_vmap(
[rank3]: ^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank3]: batched_outputs = func(*batched_inputs, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank3]: return vmap_impl(
[rank3]: ^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank3]: return _flat_vmap(
[rank3]: ^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank3]: batched_outputs = func(*batched_inputs, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank3]: return vmap_impl(
[rank3]: ^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank3]: return _flat_vmap(
[rank3]: ^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank3]: batched_outputs = func(*batched_inputs, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/dc-hl/zibing.wei/code/dflash/SpecForge/specforge/core/dflash.py", line 42, in dflash_mask_mod
[rank3]: q_block_id = q_idx // block_size
[rank3]: ~~~~~~^^~~~~~~~~~~~
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 44, in wrapped
[rank3]: return handle_torch_function(wrapped, sargs, *sargs, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/overrides.py", line 1728, in handle_torch_function
[rank3]: result = mode.torch_function(public_api, types, args, kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 146, in torch_function
[rank3]: return func(*args, **(kwargs or {}))
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 45, in wrapped
[rank3]: return f(self, *args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 1119, in floordiv
[rank3]: return torch.floor_divide(self, other)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: torch.AcceleratorError: CUDA error: device-side assert triggered
[rank3]: Search for `cudaErrorAssert` in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
[rank3]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank3]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
[rank3]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Reproduction
#!/bin/bash
source /workspace/SpecForge/.venv/bin/activate
uv pip uninstall specforge
cd /dc-hl/zibing.wei/code/dflash/SpecForge
SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)
export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=64
NUM_GPUS=8
export PYTHONWARNINGS="ignore"
TRAIN_DATA_PATH=
TARGET_MODEL_NAME_OR_PATH=/Qwen/Qwen3-8B
TRAIN_DATA_PATH=
MODEL_MAX_LENGTH=20000
WANDB_API_KEY=""
PROJECT_NAME=dflash_online
TP_SIZE=2
BATCH_SIZE=2
MODEL_NAME=qwen3_8b
TRAIN_ONLY_LAST_TURN=False
EXPERIMENT_NAME=dflash-${MODEL_NAME}-tp_size-${TP_SIZE}-batch_szie-${BATCH_SIZE}-TRAIN_ONLY_LAST_TURN${TRAIN_ONLY_LAST_TURN}-test
OUTPUT_DIR=/dc-hl/zibing.wei/output/onestepmodel_draftmodel/$EXPERIMENT_NAME
ATTENTION_BACKEND=${2:-flex_attention}
torchrun
--standalone
--nproc_per_node $NUM_GPUS
$ROOT_DIR/scripts/train_dflash.py
--target-model-path $TARGET_MODEL_NAME_OR_PATH
--draft-config-path $ROOT_DIR/configs/qwen3-8b-dflash.json
--train-data-path $TRAIN_DATA_PATH
--output-dir $OUTPUT_DIR
--num-epochs 6
--batch-size $BATCH_SIZE
--tp-size $TP_SIZE
--learning-rate 6e-4
--warmup-ratio 0.04
--max-grad-norm 1.0
--max-length $MODEL_MAX_LENGTH
--chat-template qwen
--attention-backend $ATTENTION_BACKEND
--loss-decay-gamma 7.0
--log-interval 50
--save-interval 1000
--report-to wandb
--wandb-project $PROJECT_NAME
--wandb-name $EXPERIMENT_NAME
--wandb-key $WANDB_API_KEY
--wandb-project specforge-qwen3-8b-dflash
--target-model-backend sglang
--sglang-mem-fraction-static 0.2
--train-only-last-turn $TRAIN_ONLY_LAST_TURN
--block-size 4
--wandb-name qwen3-8b-dflash-perfectblend
Environment
environment installed using requirements