Problem trying to get vLLM + Step3 + AFD running #43

@orionorin

Description

Hi StepMesh team,

We have recently been experimenting with AFD models, but we have not been able to get the Step3 AFD PR in the vLLM repo (draft AFD implementation for step3 by Oliver-ss · Pull Request #25162 · vllm-project/vllm) running, and we would like some suggestions on configuration and environment setup.

Environment

Hardware

We are using two 8x H20 servers. Each server has 9 dual-port RNICs configured with bonding; bond2~bond9 support RDMA.

FFN Server:

bond8            UP             1.2.3.176/25
bond2            UP             1.2.3.219/25
bond7            UP             1.2.3.87/25
bond9            UP             1.2.3.93/25
bond3            UP             1.2.3.9/25
bond4            UP             1.2.3.162/25
bond5            UP             1.2.3.67/25
bond1            UP             1.2.4.41/26
bond6            UP             1.2.3.205/25

Attn Server:

bond8            UP             1.2.3.211/25
bond2            UP             1.2.3.180/25
bond7            UP             1.2.3.84/25
bond9            UP             1.2.3.66/25
bond3            UP             1.2.3.6/25
bond4            UP             1.2.3.217/25
bond1            UP             1.2.4.18/26
bond5            UP             1.2.3.63/25
bond6            UP             1.2.3.202/25

Software

ATTN DP + FFN TP

When we ran Step3 (FP8) with Attn DP8 + FFN TP8 disaggregation, vLLM crashed during StepMesh scheduler initialization. The steps to reproduce are:

Start FFN server:

vllm fserver /data1/step3-fp8 \
        -tp 8 \
        --afd-config '{"afd_connector": "stepmesh", "afd_role": "ffn", "afd_host": "1.2.3.219"}' \
        --max-num-batched-tokens 384 --max-num-seqs 384 \
        --compilation-config '{"cudagraph_mode": "PIECEWISE", "cudagraph_capture_sizes": [1, 8]}'

Start Attn server:

vllm serve /data1/step3-fp8 \
        -dp 8 \
        --served-model-name step3 \
        --afd-config '{"afd_connector": "stepmesh", "afd_role": "attention", "afd_host": "1.2.3.219"}' \
        --max-num-batched-tokens 384 --max-num-seqs 384 \
        --compilation-config '{"cudagraph_mode": "PIECEWISE", "cudagraph_capture_sizes": [1, 8]}'

What we see is that each of the 8 attention DP ranks sends an ADD_NODE with aux_id=0, resulting in a collision in the StepMesh scheduler:

[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/././rdma_transport.h:413: SendRendezvousBegin 
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:703: rdma 32767     sent: ? => 1. Meta: request=0, timestamp=0, control={ cmd=ADD_NODE, node={ [role=worker, ip=1.2.3.63, port=48821, is_recovery=0, aux_id=0, num_ports=1] } }. NOT DATA MSG!, Slave QP Count: 0
bind to DMLC_NODE_HOST: 1.2.3.217
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:1099: 32767 OnConnected to 1
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:557: automatic detect interface and ip from gpu(0): bond3 (1.2.3.6)
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:153: bind to DMLC_NODE_HOST: 1.2.3.6
BindToCpuCore: gpu 0 -> cpu 11
BindToCpuCore: gpu 0 -> cpu 12
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/././rdma_transport.h:198: qp created: pd=0x7f3730000fc0 , cq=0x7f37300010d0, qp=37482
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:1099: 32767 OnConnected to 1
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:628: rdma 32767     gonna notify scheduler of myself: ? => 1. Meta: request=0, timestamp=0, control={ cmd=ADD_NODE, node={ [role=worker, ip=1.2.3.84, port=53355, is_recovery=0, aux_id=0, num_ports=1] } }. NOT DATA MSG!, Slave QP Count: 0
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/././rdma_transport.h:413: SendRendezvousBegin 
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:703: rdma 32767     sent: ? => 1. Meta: request=0, timestamp=0, control={ cmd=ADD_NODE, node={ [role=worker, ip=1.2.3.84, port=53355, is_recovery=0, aux_id=0, num_ports=1] } }. NOT DATA MSG!, Slave QP Count: 0
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:1099: 32767 OnConnected to 1
BindToCpuCore: gpu 0 -> cpu 11
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:1099: 32767 OnConnected to 1
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:628: rdma 32767     gonna notify scheduler of myself: ? => 1. Meta: request=0, timestamp=0, control={ cmd=ADD_NODE, node={ [role=worker, ip=1.2.3.66, port=45585, is_recovery=0, aux_id=0, num_ports=1] } }. NOT DATA MSG!, Slave QP Count: 0
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/././rdma_transport.h:413: SendRendezvousBegin 
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:703: rdma 32767     sent: ? => 1. Meta: request=0, timestamp=0, control={ cmd=ADD_NODE, node={ [role=worker, ip=1.2.3.66, port=45585, is_recovery=0, aux_id=0, num_ports=1] } }. NOT DATA MSG!, Slave QP Count: 0
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:173:  bound to port 56187
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:601: Bind to [role=worker, ip=1.2.3.6, port=56187, is_recovery=0, aux_id=-1, num_ports=1]
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:180: Connecting to Node 1, My_Node=32767
BindToCpuCore: gpu 0 -> cpu 12
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:173:  bound to port 40665
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:601: Bind to [role=worker, ip=1.2.3.217, port=40665, is_recovery=0, aux_id=-1, num_ports=1]
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:180: Connecting to Node 1, My_Node=32767
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:1099: 32767 OnConnected to 1
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:628: rdma 32767     gonna notify scheduler of myself: ? => 1. Meta: request=0, timestamp=0, control={ cmd=ADD_NODE, node={ [role=worker, ip=1.2.3.211, port=57059, is_recovery=0, aux_id=0, num_ports=1] } }. NOT DATA MSG!, Slave QP Count: 0
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/././rdma_transport.h:413: SendRendezvousBegin 
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:703: rdma 32767     sent: ? => 1. Meta: request=0, timestamp=0, control={ cmd=ADD_NODE, node={ [role=worker, ip=1.2.3.211, port=57059, is_recovery=0, aux_id=0, num_ports=1] } }. NOT DATA MSG!, Slave QP Count: 0
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/././rdma_transport.h:198: qp created: pd=0x7fb72c001030 , cq=0x7fb72c001140, qp=19012
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/././rdma_transport.h:198: qp created: pd=0x7fdc1c001030 , cq=0x7fdc1c001140, qp=16317
BindToCpuCore: gpu 0 -> cpu 11
BindToCpuCore: gpu 0 -> cpu 12
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:1099: 32767 OnConnected to 1
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:1099: 32767 OnConnected to 1
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:628: rdma 32767     gonna notify scheduler of myself: ? => 1. Meta: request=0, timestamp=0, control={ cmd=ADD_NODE, node={ [role=worker, ip=1.2.3.202, port=60397, is_recovery=0, aux_id=0, num_ports=1] } }. NOT DATA MSG!, Slave QP Count: 0
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/././rdma_transport.h:413: SendRendezvousBegin 
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:703: rdma 32767     sent: ? => 1. Meta: request=0, timestamp=0, control={ cmd=ADD_NODE, node={ [role=worker, ip=1.2.3.202, port=60397, is_recovery=0, aux_id=0, num_ports=1] } }. NOT DATA MSG!, Slave QP Count: 0
BindToCpuCore: gpu 0 -> cpu 11
BindToCpuCore: gpu 0 -> cpu 12
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/././rdma_transport.h:198: qp created: pd=0x7fb72c001030 , cq=0x7fb72c001140, qp=19013
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/././rdma_transport.h:198: qp created: pd=0x7fdc1c001030 , cq=0x7fdc1c001140, qp=16318
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:1099: 32767 OnConnected to 1
BindToCpuCore: gpu 0 -> cpu 11
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:1099: 32767 OnConnected to 1
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:628: rdma 32767     gonna notify scheduler of myself: ? => 1. Meta: request=0, timestamp=0, control={ cmd=ADD_NODE, node={ [role=worker, ip=1.2.3.180, port=42787, is_recovery=0, aux_id=0, num_ports=1] } }. NOT DATA MSG!, Slave QP Count: 0
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/././rdma_transport.h:413: SendRendezvousBegin 
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:703: rdma 32767     sent: ? => 1. Meta: request=0, timestamp=0, control={ cmd=ADD_NODE, node={ [role=worker, ip=1.2.3.180, port=42787, is_recovery=0, aux_id=0, num_ports=1] } }. NOT DATA MSG!, Slave QP Count: 0
BindToCpuCore: gpu 0 -> cpu 12
BindToCpuCore: gpu 0 -> cpu 11
BindToCpuCore: gpu 0 -> cpu 12
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:1099: 32767 OnConnected to 1
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:1099: 32767 OnConnected to 1
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:628: rdma 32767     gonna notify scheduler of myself: ? => 1. Meta: request=0, timestamp=0, control={ cmd=ADD_NODE, node={ [role=worker, ip=1.2.3.6, port=56187, is_recovery=0, aux_id=0, num_ports=1] } }. NOT DATA MSG!, Slave QP Count: 0
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/././rdma_transport.h:413: SendRendezvousBegin 
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:703: rdma 32767     sent: ? => 1. Meta: request=0, timestamp=0, control={ cmd=ADD_NODE, node={ [role=worker, ip=1.2.3.6, port=56187, is_recovery=0, aux_id=0, num_ports=1] } }. NOT DATA MSG!, Slave QP Count: 0
BindToCpuCore: gpu 0 -> cpu 11
BindToCpuCore: gpu 0 -> cpu 12
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:1099: 32767 OnConnected to 1
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/./rdma_van.h:1099: 32767 OnConnected to 1
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:628: rdma 32767     gonna notify scheduler of myself: ? => 1. Meta: request=0, timestamp=0, control={ cmd=ADD_NODE, node={ [role=worker, ip=1.2.3.217, port=40665, is_recovery=0, aux_id=0, num_ports=1] } }. NOT DATA MSG!, Slave QP Count: 0
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/././rdma_transport.h:413: SendRendezvousBegin 
[14:11:52] worker 0 /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:703: rdma 32767     sent: ? => 1. Meta: request=0, timestamp=0, control={ cmd=ADD_NODE, node={ [role=worker, ip=1.2.3.217, port=40665, is_recovery=0, aux_id=0, num_ports=1] } }. NOT DATA MSG!, Slave QP Count: 0
[14:11:52] scheduler 0 /nfs_disk/XXXXXX/afd/StepMesh/include/dmlc/logging.h:301: [14:11:52] /nfs_disk/XXXXXX/afd/StepMesh/src/van.cc:226: Check failed: (worker_ranks.find(node.aux_id)) == (worker_ranks.end()) rank must be unique: [role=worker, ip=1.2.3.63, port=48821, is_recovery=0, aux_id=0, num_ports=1]

It seems to me that this aux_id ultimately comes from gpu_model_runner.py in vLLM, and for some reason vLLM obtained an incorrect local_rank.

        # init AFD config
        self.afd_config = vllm_config.afd_config
        if self.afd_config and self.afd_config.afd_role == "attention":
            self.afd_connector = AFDConnectorFactory.create_connector(
                get_world_group().rank,
                get_world_group().local_rank, vllm_config)
            self.afd_connector.init_afd_connector()
            self.num_stages = self.afd_config.num_afd_stages
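
If our reading is correct, with -dp 8 each attention engine runs in its own world group, so get_world_group().rank and local_rank are both 0 on every DP rank, and all eight workers register with aux_id=0. A minimal sketch of the kind of change we have in mind (our assumption, not the PR's code; using data_parallel_rank as the offset is a guess):

    # Hypothetical sketch, not the PR's code: fold the data-parallel rank into the
    # rank handed to the connector so every attention DP rank registers a unique aux_id.
    dp_rank = vllm_config.parallel_config.data_parallel_rank   # 0..7 with -dp 8
    world = get_world_group()
    unique_rank = dp_rank * world.world_size + world.rank      # world.rank is 0 per engine

    self.afd_connector = AFDConnectorFactory.create_connector(
        unique_rank,          # instead of the per-engine rank that collides at 0
        world.local_rank,
        vllm_config)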

ATTN TP + FFN TP (If this is a supported config)

Dynamo Issue

vLLM could not start with the default configuration. The error messages below seem to indicate that the call to time.time() in the AFD metadata creation code cannot be traced by Dynamo, so we used --enforce-eager to disable compilation in later runs (a possible workaround sketch follows the traceback below).

(Worker_TP1 pid=25649) INFO 11-11 15:19:14 [gpu_model_runner.py:2470] Starting to load model /data1/step3-fp8...
(Worker_TP1 pid=25649) INFO 11-11 15:19:14 [gpu_model_runner.py:2502] Loading model from scratch...
(Worker_TP1 pid=25649) INFO 11-11 15:19:14 [layer.py:437] MultiHeadAttention attn_backend: _Backend.XFORMERS, use_upstream_fa: False
(Worker_TP1 pid=25649) INFO 11-11 15:19:15 [cuda.py:368] Using Flash Attention backend on V1 engine.
(Worker_TP1 pid=25649) WARNING 11-11 15:19:15 [fp8.py:574] CutlassBlockScaledGroupedGemm not supported on the current platform.
(Worker_TP1 pid=25649) INFO 11-11 15:19:38 [default_loader.py:268] Loading weights took 21.74 seconds
(Worker_TP1 pid=25649) INFO 11-11 15:19:39 [gpu_model_runner.py:2524] Model loading took 39.8007 GiB and 23.545146 seconds
(Worker_TP1 pid=25649) INFO 11-11 15:19:39 [gpu_model_runner.py:3186] Encoder cache will be initialized with a budget of 3164 tokens, and profiled with 1 image items of the maximum feature size.
(Worker_TP1 pid=25649) WARNING 11-11 15:19:39 [__init__.py:2179] The following intended overrides are not keyword args and will be dropped: {'truncation'}
(Worker_TP1 pid=25649) WARNING 11-11 15:19:39 [__init__.py:2179] The following intended overrides are not keyword args and will be dropped: {'truncation'}
(Worker_TP1 pid=25649) /usr/local/lib/python3.12/dist-packages/torch/_dynamo/variables/functions.py:1481: UserWarning: Dynamo does not know how to trace the builtin `time.time.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
(Worker_TP1 pid=25649) If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
(Worker_TP1 pid=25649) If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
(Worker_TP1 pid=25649)   torch._dynamo.utils.warn_once(explanation + "\n" + "\n".join(hints))
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] WorkerProc hit an exception.
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] Traceback (most recent call last):
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/v1/executor/multiproc_executor.py", line 666, in worker_busy_loop
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     output = func(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]              ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     return func(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/v1/worker/gpu_worker.py", line 270, in determine_available_memory
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     self.model_runner.profile_run()
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/v1/worker/gpu_model_runner.py", line 3217, in profile_run
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     = self._dummy_run(self.max_num_tokens, is_profile=True)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     return func(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/v1/worker/gpu_model_runner.py", line 2995, in _dummy_run
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     outputs = self.model(
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]               ^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     return self._call_impl(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     return forward_call(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_vl.py", line 1065, in forward
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     hidden_states = self.language_model(input_ids,
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     return self._call_impl(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     return forward_call(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_text.py", line 609, in forward
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     hidden_states = self.model(input_ids, positions, intermediate_tensors,
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/compilation/decorators.py", line 305, in __call__
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     output = self.compiled_callable(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 745, in compile_wrapper
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     raise e.with_traceback(None) from e.__cause__  # User compiler error
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   Explanation: Dynamo does not know how to trace the builtin `time.time.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   Hint: If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   Hint: If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] 
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   Developer debug context: module: time, qualname: time, skip reason: <missing reason>
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] 
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] 
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] from user code:
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]    File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_text.py", line 523, in forward
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     metadata = AFDConnectorMetadata.create_attention_metadata(
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/distributed/afd_transfer/afd_connector/metadata.py", line 72, in create_attention_metadata
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     timestamp=time.time())
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] 
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] 
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] Traceback (most recent call last):
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/v1/executor/multiproc_executor.py", line 666, in worker_busy_loop
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     output = func(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]              ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     return func(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/v1/worker/gpu_worker.py", line 270, in determine_available_memory
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     self.model_runner.profile_run()
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/v1/worker/gpu_model_runner.py", line 3217, in profile_run
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     = self._dummy_run(self.max_num_tokens, is_profile=True)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]       ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     return func(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/v1/worker/gpu_model_runner.py", line 2995, in _dummy_run
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     outputs = self.model(
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]               ^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     return self._call_impl(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     return forward_call(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_vl.py", line 1065, in forward
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     hidden_states = self.language_model(input_ids,
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     return self._call_impl(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     return forward_call(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_text.py", line 609, in forward
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     hidden_states = self.model(input_ids, positions, intermediate_tensors,
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/compilation/decorators.py", line 305, in __call__
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     output = self.compiled_callable(*args, **kwargs)
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/_dynamo/eval_frame.py", line 745, in compile_wrapper
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     raise e.with_traceback(None) from e.__cause__  # User compiler error
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] torch._dynamo.exc.Unsupported: Attempted to call function marked as skipped
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   Explanation: Dynamo does not know how to trace the builtin `time.time.` This function is either a Python builtin (e.g. _warnings.warn) or a third-party C/C++ Python extension (perhaps created with pybind).
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   Hint: If it is a Python builtin, please file an issue on GitHub so the PyTorch team can add support for it and see the next case for a workaround.
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   Hint: If it is a third-party C/C++ Python extension, please either wrap it into a PyTorch-understood custom operator (see https://pytorch.org/tutorials/advanced/custom_ops_landing_page.html for more details) or, if it is traceable, use `torch.compiler.allow_in_graph`.
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] 
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   Developer debug context: module: time, qualname: time, skip reason: <missing reason>
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] 
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] 
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] from user code:
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]    File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_text.py", line 523, in forward
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     metadata = AFDConnectorMetadata.create_attention_metadata(
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/distributed/afd_transfer/afd_connector/metadata.py", line 72, in create_attention_metadata
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]     timestamp=time.time())
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] 
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] Set TORCHDYNAMO_VERBOSE=1 for the internal stack trace (please do this especially if you're reporting a bug to PyTorch). For even more developer context, set TORCH_LOGS="+dynamo"
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671] 
(Worker_TP1 pid=25649) ERROR 11-11 15:19:45 [multiproc_executor.py:671]
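
For reference, one way we can imagine keeping the untraceable time.time() call out of the compiled graph (a hedged sketch on our side, not something from the PR) is to force a graph break around the metadata creation:

    import torch

    from vllm.distributed.afd_transfer.afd_connector.metadata import AFDConnectorMetadata

    # Hedged sketch, not the PR's code: torch.compiler.disable makes Dynamo fall back
    # to eager for this call, so the internal time.time() timestamp is never traced.
    create_attention_metadata_eager = torch.compiler.disable(
        AFDConnectorMetadata.create_attention_metadata)

    # step3_text.py's forward would then call create_attention_metadata_eager(...)
    # in place of AFDConnectorMetadata.create_attention_metadata(...).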

Inference Issue

With --enforce-eager we are able to bring up the servers:

vllm fserver /data1/step3-fp8 \
        -tp 8 \
        --afd-config '{"afd_connector": "stepmesh", "afd_role": "ffn", "afd_host": "1.2.3.219"}' \
        --max-num-batched-tokens 384 --max-num-seqs 384 \
        --enforce-eager
vllm serve /data1/step3-fp8 \
        -tp 8 \
        --served-model-name step3 \
        --afd-config '{"afd_connector": "stepmesh", "afd_role": "attention", "afd_host": "1.2.3.219"}' --max-num-batched-tokens 384 --max-num-seqs 384 \
        --enforce-eager

But after sending a request, we got an error:

(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671] WorkerProc hit an exception.
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671] Traceback (most recent call last):
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/v1/executor/multiproc_executor.py", line 666, in worker_busy_loop
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     output = func(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]              ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return func(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/v1/worker/gpu_worker.py", line 462, in execute_model
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     output = self.model_runner.execute_model(scheduler_output,
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return func(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/v1/worker/gpu_model_runner.py", line 2181, in execute_model
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     model_output = self.model(
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]                    ^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return self._call_impl(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return forward_call(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_vl.py", line 1065, in forward
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     hidden_states = self.language_model(input_ids,
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return self._call_impl(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return forward_call(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_text.py", line 609, in forward
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     hidden_states = self.model(input_ids, positions, intermediate_tensors,
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/compilation/decorators.py", line 223, in __call__
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return self.forward(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_text.py", line 519, in forward
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     layer.compute_attn_output(
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_text.py", line 383, in compute_attn_output
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return self._compute_attn_output(hidden_states, residual,
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_text.py", line 346, in _compute_attn_output
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     hidden_states = self.self_attn(positions=positions,
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return self._call_impl(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return forward_call(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_text.py", line 196, in forward
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     qkv, _ = self.qkv_proj(hidden_states)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return self._call_impl(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return forward_call(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/layers/linear.py", line 375, in forward
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     output = self.quant_method.apply(self, x, bias)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/layers/quantization/fp8.py", line 493, in apply
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return torch.ops.vllm.apply_w8a8_block_fp8_linear(
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1243, in __call__
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return self._op(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/layers/quantization/utils/fp8_utils.py", line 167, in apply_w8a8_block_fp8_linear
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     input_2d = torch.nn.functional.pad(input_2d,
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/functional.py", line 5290, in pad
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return torch._C._nn.pad(input, pad, mode, value)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671] torch.AcceleratorError: CUDA error: invalid configuration argument
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671] CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671] For debugging consider passing CUDA_LAUNCH_BLOCKING=1
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671] 
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671] Traceback (most recent call last):
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/v1/executor/multiproc_executor.py", line 666, in worker_busy_loop
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     output = func(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]              ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return func(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/v1/worker/gpu_worker.py", line 462, in execute_model
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     output = self.model_runner.execute_model(scheduler_output,
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/utils/_contextlib.py", line 120, in decorate_context
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return func(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/v1/worker/gpu_model_runner.py", line 2181, in execute_model
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     model_output = self.model(
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]                    ^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return self._call_impl(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return forward_call(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_vl.py", line 1065, in forward
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     hidden_states = self.language_model(input_ids,
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return self._call_impl(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return forward_call(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_text.py", line 609, in forward
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     hidden_states = self.model(input_ids, positions, intermediate_tensors,
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/compilation/decorators.py", line 223, in __call__
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return self.forward(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_text.py", line 519, in forward
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     layer.compute_attn_output(
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_text.py", line 383, in compute_attn_output
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return self._compute_attn_output(hidden_states, residual,
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_text.py", line 346, in _compute_attn_output
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     hidden_states = self.self_attn(positions=positions,
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return self._call_impl(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return forward_call(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/models/step3_text.py", line 196, in forward
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     qkv, _ = self.qkv_proj(hidden_states)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1773, in _wrapped_call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return self._call_impl(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/modules/module.py", line 1784, in _call_impl
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return forward_call(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/layers/linear.py", line 375, in forward
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     output = self.quant_method.apply(self, x, bias)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/layers/quantization/fp8.py", line 493, in apply
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return torch.ops.vllm.apply_w8a8_block_fp8_linear(
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/_ops.py", line 1243, in __call__
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return self._op(*args, **kwargs)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/nfs_disk/XXXXXX/afd/vllm-afd/vllm/model_executor/layers/quantization/utils/fp8_utils.py", line 167, in apply_w8a8_block_fp8_linear
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     input_2d = torch.nn.functional.pad(input_2d,
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]                ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]   File "/usr/local/lib/python3.12/dist-packages/torch/nn/functional.py", line 5290, in pad
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]     return torch._C._nn.pad(input, pad, mode, value)
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671] torch.AcceleratorError: CUDA error: invalid configuration argument
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671] CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671] For debugging consider passing CUDA_LAUNCH_BLOCKING=1
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671] Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671] 
(Worker_TP1 pid=26589) ERROR 11-11 15:41:15 [multiproc_executor.py:671]
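
To narrow this down, we plan to rerun with synchronous launches and exercise the failing pad in isolation; a rough diagnostic sketch (ours, with placeholder shapes, not vLLM code):

    import os
    os.environ["CUDA_LAUNCH_BLOCKING"] = "1"   # synchronous kernel launches, per the hint in the log

    import torch
    import torch.nn.functional as F

    # Rough diagnostic sketch (ours, not vLLM code): run F.pad in isolation for a few
    # per-rank token counts to see whether "invalid configuration argument" reproduces
    # for the shapes this rank sees. The hidden size (7168) and pad spec are placeholders;
    # the real values come from fp8_utils.apply_w8a8_block_fp8_linear.
    for num_tokens in (0, 1, 8, 384):
        x = torch.randn(num_tokens, 7168, device="cuda", dtype=torch.bfloat16)
        y = F.pad(x, (0, 64))
        torch.cuda.synchronize()
        print(num_tokens, tuple(y.shape))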
