[Bug] Dflash #501

@Jim2016713

Description

Checklist

  1. I have searched related issues but cannot get the expected help.
  2. The bug has not been fixed in the latest version.
  3. Please note that if the bug-related issue you submitted lacks corresponding environment info and a minimal reproducible demo, it will be challenging for us to reproduce and resolve the issue, reducing the likelihood of receiving feedback.
  4. If the issue you raised is not a bug but a question, please raise a discussion at https://github.com/sgl-project/SpecForge/discussions/new/choose. Otherwise, it will be closed.
  5. Please use English; otherwise, the issue will be closed.

Describe the bug

When training with dflash, I encountered this error:
[rank3]: Traceback (most recent call last):
[rank3]: File "/dc-hl/zibing.wei/code/dflash/SpecForge/scripts/train_dflash.py", line 548, in <module>
[rank3]: main()
[rank3]: File "/dc-hl/zibing.wei/code/dflash/SpecForge/scripts/train_dflash.py", line 493, in main
[rank3]: loss, accuracy = dflash_model(
[rank3]: ^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]: return self._call_impl(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]: return forward_call(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/distributed/fsdp/fully_sharded_data_parallel.py", line 851, in forward
[rank3]: output = self._fsdp_wrapped_module(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1775, in _wrapped_call_impl
[rank3]: return self._call_impl(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1786, in _call_impl
[rank3]: return forward_call(*args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/dc-hl/zibing.wei/code/dflash/SpecForge/specforge/core/dflash.py", line 208, in forward
[rank3]: dflash_attn_mask = create_dflash_block_mask(
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/dc-hl/zibing.wei/code/dflash/SpecForge/specforge/core/dflash.py", line 61, in create_dflash_block_mask
[rank3]: return create_block_mask(
[rank3]: ^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/attention/flex_attention.py", line 1091, in create_block_mask
[rank3]: mask_tensor = create_mask(mask_mod, B, H, Q_LEN, KV_LEN, device)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/nn/attention/flex_attention.py", line 1021, in create_mask
[rank3]: mask = mask_mod(b, h, m, n)
[rank3]: ^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank3]: return vmap_impl(
[rank3]: ^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank3]: return _flat_vmap(
[rank3]: ^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank3]: batched_outputs = func(*batched_inputs, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank3]: return vmap_impl(
[rank3]: ^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank3]: return _flat_vmap(
[rank3]: ^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank3]: batched_outputs = func(*batched_inputs, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank3]: return vmap_impl(
[rank3]: ^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank3]: return _flat_vmap(
[rank3]: ^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank3]: batched_outputs = func(*batched_inputs, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/apis.py", line 208, in wrapped
[rank3]: return vmap_impl(
[rank3]: ^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 282, in vmap_impl
[rank3]: return _flat_vmap(
[rank3]: ^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_functorch/vmap.py", line 432, in _flat_vmap
[rank3]: batched_outputs = func(*batched_inputs, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/dc-hl/zibing.wei/code/dflash/SpecForge/specforge/core/dflash.py", line 42, in dflash_mask_mod
[rank3]: q_block_id = q_idx // block_size
[rank3]: ~~~~~~^^~~~~~~~~~~~
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 44, in wrapped
[rank3]: return handle_torch_function(wrapped, sargs, *sargs, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/overrides.py", line 1728, in handle_torch_function
[rank3]: result = mode.torch_function(public_api, types, args, kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_dynamo/_trace_wrapped_higher_order_op.py", line 146, in torch_function
[rank3]: return func(*args, **(kwargs or {}))
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 45, in wrapped
[rank3]: return f(self, *args, **kwargs)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: File "/workspace/SpecForge/.venv/lib/python3.11/site-packages/torch/_tensor.py", line 1119, in floordiv
[rank3]: return torch.floor_divide(self, other)
[rank3]: ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank3]: torch.AcceleratorError: CUDA error: device-side assert triggered
[rank3]: Search for `cudaErrorAssert` in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
[rank3]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank3]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[rank3]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
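For context, the frame that fails is the floor division inside the mask predicate (`dflash.py`, line 42). A minimal sketch of the block-causal arithmetic implied by the traceback, with plain integers instead of tensors (the function name and the attend rule here are my assumptions, not the actual SpecForge code):

```python
def dflash_mask_allowed(q_idx: int, kv_idx: int, block_size: int) -> bool:
    # Map token positions to draft-block ids; the traceback dies on this
    # floor division when run under flex_attention's vmapped mask_mod.
    q_block_id = q_idx // block_size
    kv_block_id = kv_idx // block_size
    # Assumed block-causal rule: a query block may attend to itself
    # and to earlier blocks only.
    return kv_block_id <= q_block_id

# With block_size=4: position 5 is in block 1, position 3 in block 0,
# so (q=5, kv=3) is allowed while (q=3, kv=5) is not.
```

The arithmetic itself cannot assert; the device-side assert most likely comes from an invalid value in an upstream index tensor reaching this predicate. Rerunning with `CUDA_LAUNCH_BLOCKING=1`, as the error message suggests, should pin the assert to the exact kernel.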

Reproduction

#!/bin/bash
source /workspace/SpecForge/.venv/bin/activate
uv pip uninstall specforge
cd /dc-hl/zibing.wei/code/dflash/SpecForge

SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd )
ROOT_DIR=$(dirname $SCRIPT_DIR)
export TORCHINDUCTOR_CACHE_DIR=$ROOT_DIR/cache/compiled_kernels
export SPECFORGE_DATA_NUM_PROC=64
NUM_GPUS=8
export PYTHONWARNINGS="ignore"

TRAIN_DATA_PATH=

TARGET_MODEL_NAME_OR_PATH=/Qwen/Qwen3-8B
TRAIN_DATA_PATH=

MODEL_MAX_LENGTH=20000

WANDB_API_KEY=""
PROJECT_NAME=dflash_online
TP_SIZE=2
BATCH_SIZE=2
MODEL_NAME=qwen3_8b
TRAIN_ONLY_LAST_TURN=False
EXPERIMENT_NAME=dflash-${MODEL_NAME}-tp_size-${TP_SIZE}-batch_size-${BATCH_SIZE}-TRAIN_ONLY_LAST_TURN${TRAIN_ONLY_LAST_TURN}-test
OUTPUT_DIR=/dc-hl/zibing.wei/output/onestepmodel_draftmodel/$EXPERIMENT_NAME

ATTENTION_BACKEND=${2:-flex_attention}

torchrun \
    --standalone \
    --nproc_per_node $NUM_GPUS \
    $ROOT_DIR/scripts/train_dflash.py \
    --target-model-path $TARGET_MODEL_NAME_OR_PATH \
    --draft-config-path $ROOT_DIR/configs/qwen3-8b-dflash.json \
    --train-data-path $TRAIN_DATA_PATH \
    --output-dir $OUTPUT_DIR \
    --num-epochs 6 \
    --batch-size $BATCH_SIZE \
    --tp-size $TP_SIZE \
    --learning-rate 6e-4 \
    --warmup-ratio 0.04 \
    --max-grad-norm 1.0 \
    --max-length $MODEL_MAX_LENGTH \
    --chat-template qwen \
    --attention-backend $ATTENTION_BACKEND \
    --loss-decay-gamma 7.0 \
    --log-interval 50 \
    --save-interval 1000 \
    --report-to wandb \
    --wandb-project $PROJECT_NAME \
    --wandb-name $EXPERIMENT_NAME \
    --wandb-key $WANDB_API_KEY \
    --wandb-project specforge-qwen3-8b-dflash \
    --target-model-backend sglang \
    --sglang-mem-fraction-static 0.2 \
    --train-only-last-turn $TRAIN_ONLY_LAST_TURN \
    --block-size 4 \
    --wandb-name qwen3-8b-dflash-perfectblend

Environment

Environment installed using the repository's requirements file.
