diff --git a/.buildkite/lm-eval-harness/configs/models-large-rocm.txt b/.buildkite/lm-eval-harness/configs/models-large-rocm.txt index 4fb0b84bc4d8..a9a60f348d6a 100644 --- a/.buildkite/lm-eval-harness/configs/models-large-rocm.txt +++ b/.buildkite/lm-eval-harness/configs/models-large-rocm.txt @@ -1 +1,2 @@ Meta-Llama-4-Maverick-17B-128E-Instruct-FP8.yaml +Qwen3-235B-A22B-Instruct-2507-FP8.yaml diff --git a/.buildkite/scripts/hardware_ci/run-amd-test.sh b/.buildkite/scripts/hardware_ci/run-amd-test.sh index 89736eec1273..aa84d0e8aae9 100755 --- a/.buildkite/scripts/hardware_ci/run-amd-test.sh +++ b/.buildkite/scripts/hardware_ci/run-amd-test.sh @@ -6,6 +6,26 @@ # Multi-node detection: Instead of matching on fragile group names, we detect # multi-node jobs structurally by looking for the bracket command syntax # "[node0_cmds] && [node1_cmds]" or via the NUM_NODES environment variable. +# +############################################################################### +# QUOTING / COMMAND PASSING +# +# Passing commands as positional arguments ($*) is fragile when the command +# string itself contains double quotes, e.g.: +# +# bash run-amd-test.sh "export FLAGS="value" && pytest -m "not slow"" +# +# The outer shell resolves the nested quotes *before* this script runs, so +# the script receives mangled input it cannot fully recover. +# +# Preferred: pass commands via the VLLM_TEST_COMMANDS environment variable: +# +# export VLLM_TEST_COMMANDS='export FLAGS="value" && pytest -m "not slow"' +# bash run-amd-test.sh +# +# Single-quoted assignment preserves all inner double quotes verbatim. +# The $* path is kept for backward compatibility but callers should migrate. +############################################################################### set -o pipefail # Export Python path @@ -80,25 +100,140 @@ is_multi_node() { } ############################################################################### -# Pytest marker re-quoting +# Pytest marker/keyword re-quoting # # When commands are passed through Buildkite -> shell -> $* -> bash -c, -# quotes around pytest -m marker expressions get stripped: +# quotes around multi-word pytest -m/-k expressions get stripped: # pytest -v -s -m 'not cpu_test' v1/core # becomes: # pytest -v -s -m not cpu_test v1/core # # pytest then interprets "cpu_test" as a file path, not part of the marker. -# This function detects unquoted multi-word marker expressions and re-quotes -# them so they survive the final bash -c expansion. +# +# This function detects unquoted expressions after -m/-k and re-quotes them +# by collecting tokens until a recognizable boundary is reached: +# - test path (contains '/') +# - test file (ends with '.py') +# - another pytest flag (--xxx or -x single-char flags) +# - command separator (&& || ; |) +# - environment variable assignment (FOO=bar) +# +# Single-word markers (e.g. -m cpu_test, -m hybrid_model) pass through +# unquoted since they have no spaces and work fine. +# +# Already-quoted expressions (containing literal single quotes) are passed +# through untouched to avoid double-quoting values injected by +# apply_rocm_test_overrides. +# +# NOTE: This ONLY fixes -m/-k flags. It cannot recover arbitrary inner +# double-quotes stripped by the calling shell (see header comment). +# Use VLLM_TEST_COMMANDS to avoid the problem entirely. ############################################################################### - re_quote_pytest_markers() { - local cmds="$1" - # Pattern: -m not -> -m 'not ' - # Handles the common cases: 'not cpu_test', 'not slow_test', etc. - cmds=$(echo "$cmds" | sed -E "s/-m not ([a-zA-Z_][a-zA-Z0-9_]*)/-m 'not \1'/g") - echo "$cmds" + local input="$1" + local output="" + local collecting=false + local marker_buf="" + + # Flatten newlines for consistent tokenization + local flat="${input//$'\n'/ }" + + # Disable globbing to prevent *.py etc. from expanding during read -ra + local restore_glob + restore_glob="$(shopt -p -o noglob 2>/dev/null || true)" + set -o noglob + local -a words + read -ra words <<< "$flat" + eval "$restore_glob" + + for word in "${words[@]}"; do + if $collecting; then + # If the token we're about to collect already contains a literal + # single quote, the expression was already quoted upstream. + # Flush and stop collecting. + if [[ "$word" == *"'"* ]]; then + if [[ -n "$marker_buf" ]]; then + # Should not normally happen (partial buf + quote), flush raw + output+="${marker_buf} " + marker_buf="" + fi + output+="${word} " + collecting=false + continue + fi + + local is_boundary=false + case "$word" in + # Command separators + "&&"|"||"|";"|"|") + is_boundary=true ;; + # Long flags (--ignore, --shard-id, etc.) + --*) + is_boundary=true ;; + # Short flags (-v, -s, -x, etc.) but NOT negative marker tokens + # like "not" which don't start with "-". Also skip -k/-m which + # would start a new marker (handled below). + -[a-zA-Z]) + is_boundary=true ;; + # Test path (contains /) + */*) + is_boundary=true ;; + # Test file (ends with .py, possibly with ::method) + *.py|*.py::*) + is_boundary=true ;; + # Environment variable assignment preceding a command (FOO=bar) + *=*) + # Only treat as boundary if it looks like VAR=value, not + # pytest filter expressions like num_gpus=2 inside markers + if [[ "$word" =~ ^[A-Z_][A-Z0-9_]*= ]]; then + is_boundary=true + fi + ;; + esac + + if $is_boundary; then + # Flush the collected marker expression + if [[ "$marker_buf" == *" "* || "$marker_buf" == *"("* ]]; then + output+="'${marker_buf}' " + else + output+="${marker_buf} " + fi + collecting=false + marker_buf="" + # Check if this boundary word itself starts a new -m/-k + if [[ "$word" == "-m" || "$word" == "-k" ]]; then + output+="${word} " + collecting=true + else + output+="${word} " + fi + else + # Accumulate into marker buffer + if [[ -n "$marker_buf" ]]; then + marker_buf+=" ${word}" + else + marker_buf="${word}" + fi + fi + elif [[ "$word" == "-m" || "$word" == "-k" ]]; then + output+="${word} " + collecting=true + marker_buf="" + else + output+="${word} " + fi + done + + # Flush any trailing marker expression (marker at end of command) + if $collecting && [[ -n "$marker_buf" ]]; then + if [[ "$marker_buf" == *" "* || "$marker_buf" == *"("* ]]; then + output+="'${marker_buf}'" + else + output+="${marker_buf}" + fi + fi + + echo "${output% }" } ############################################################################### @@ -231,11 +366,35 @@ HF_CACHE="$(realpath ~)/huggingface" mkdir -p "${HF_CACHE}" HF_MOUNT="/root/.cache/huggingface" -commands="$*" +# ---- Command source selection ---- +# Prefer VLLM_TEST_COMMANDS (preserves all inner quoting intact). +# Fall back to $* for backward compatibility, but warn that inner +# double-quotes will have been stripped by the calling shell. +if [[ -n "${VLLM_TEST_COMMANDS:-}" ]]; then + commands="${VLLM_TEST_COMMANDS}" + echo "Commands sourced from VLLM_TEST_COMMANDS (quoting preserved)" +else + commands="$*" + if [[ -z "$commands" ]]; then + echo "Error: No test commands provided." >&2 + echo "Usage:" >&2 + echo " Preferred: VLLM_TEST_COMMANDS='...' bash $0" >&2 + echo " Legacy: bash $0 \"commands here\"" >&2 + exit 1 + fi + echo "Commands sourced from positional args (legacy mode)" + echo "WARNING: Inner double-quotes in the command string may have been" + echo " stripped by the calling shell. If you see syntax errors, switch to:" + echo " export VLLM_TEST_COMMANDS='your commands here'" + echo " bash $0" +fi + echo "Raw commands: $commands" # Fix quoting before ROCm overrides (so overrides see correct structure) commands=$(re_quote_pytest_markers "$commands") +echo "After re-quoting: $commands" + commands=$(apply_rocm_test_overrides "$commands") echo "Final commands: $commands" @@ -248,6 +407,18 @@ if [[ -z "$render_gid" ]]; then exit 1 fi +# --- RDMA device passthrough (conditional) --- +# If the host has RDMA devices, pass them through so tests like +# test_moriio_connector can access ibverbs. On hosts without RDMA +# hardware the tests will gracefully skip via _rdma_available(). +RDMA_FLAGS="" +if [ -d /dev/infiniband ]; then + echo "RDMA devices detected on host, enabling passthrough" + RDMA_FLAGS="--device /dev/infiniband --cap-add=IPC_LOCK" +else + echo "No RDMA devices found on host, RDMA tests will be skipped" +fi + # --- Route: multi-node vs single-node --- if is_multi_node "$commands"; then echo "--- Multi-node job detected" @@ -295,6 +466,7 @@ else echo "Render devices: $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES" docker run \ --device /dev/kfd $BUILDKITE_AGENT_META_DATA_RENDER_DEVICES \ + $RDMA_FLAGS \ --network=host \ --shm-size=16gb \ --group-add "$render_gid" \ diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index ffdf4b83c0e2..c5db1ca83634 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -156,8 +156,9 @@ steps: - label: Entrypoints Integration Test (API Server 1) # 100min timeout_in_minutes: 130 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 + optional: true # grade: Blocking working_dir: "/vllm-workspace/tests" fast_check: true @@ -173,8 +174,9 @@ steps: - label: Entrypoints Integration Test (API Server 2) timeout_in_minutes: 50 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 + optional: true # grade: Blocking working_dir: "/vllm-workspace/tests" fast_check: true @@ -192,8 +194,9 @@ steps: - label: Entrypoints Integration Test (Pooling) timeout_in_minutes: 50 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 + optional: true # grade: Blocking working_dir: "/vllm-workspace/tests" fast_check: true @@ -207,8 +210,9 @@ steps: - label: Entrypoints Integration Test (Responses API) timeout_in_minutes: 50 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 + optional: true # grade: Blocking working_dir: "/vllm-workspace/tests" fast_check: true @@ -222,8 +226,9 @@ steps: - label: Distributed Tests (4 GPUs) # 35min timeout_in_minutes: 50 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_4 + optional: true # grade: Blocking working_dir: "/vllm-workspace/tests" num_gpus: 4 @@ -278,14 +283,16 @@ steps: - popd # NEW rlhf examples - pushd ../examples/offline_inference/new_weight_syncing - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_async_new_apis.py - popd - label: Distributed Tests (8 GPUs) # 4min timeout_in_minutes: 10 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_8 + optional: true # grade: Blocking gpu: h100 num_gpus: 8 @@ -380,10 +387,9 @@ steps: - label: V1 Test e2e + engine # 65min timeout_in_minutes: 90 - mirror_hardwares: [amdexperimental] - # The test uses 4 GPUs, but we schedule it on 8-GPU machines for stability. - # See discussion here: https://github.com/vllm-project/vllm/pull/31040 - agent_pool: mi325_8 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_1 + optional: true # grade: Blocking source_file_dependencies: - vllm/ @@ -394,6 +400,34 @@ steps: - pytest -v -s v1/e2e - pytest -v -s v1/engine +- label: V1 Test e2e (2 GPUs) # 65min + timeout_in_minutes: 90 + mirror_hardwares: [amdexperimental, amdproduction] + agent_pool: mi325_2 + optional: true + # grade: Blocking + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + # Only run tests that need exactly 2 GPUs + - pytest -v -s v1/e2e/test_spec_decode.py -k "tensor_parallelism" + +- label: V1 Test e2e (4 GPUs) # 65min + timeout_in_minutes: 90 + mirror_hardwares: [amdexperimental, amdproduction] + # The test uses 4 GPUs, but we schedule it on 8-GPU machines for stability. + # See discussion here: https://github.com/vllm-project/vllm/pull/31040 + agent_pool: mi325_4 + optional: true + # grade: Blocking + source_file_dependencies: + - vllm/ + - tests/v1 + commands: + # Only run tests that need 4 GPUs + - pytest -v -s v1/e2e/test_spec_decode.py -k "eagle_correctness_heavy" + - label: V1 Test entrypoints # 35min timeout_in_minutes: 50 mirror_hardwares: [amdexperimental, amdproduction, amdtentative] @@ -407,8 +441,9 @@ steps: - label: V1 Test others # 42min timeout_in_minutes: 60 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 + optional: true # grade: Blocking source_file_dependencies: - vllm/ @@ -435,8 +470,9 @@ steps: # TODO: Add the "V1 Test attetion (MI300)" test group - label: V1 Test attention (H100) # 10min - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 + optional: true # grade: Blocking timeout_in_minutes: 30 gpu: h100 @@ -540,8 +576,9 @@ steps: - label: Samplers Test # 56min timeout_in_minutes: 75 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 + optional: true # grade: Blocking source_file_dependencies: - vllm/model_executor/layers @@ -553,8 +590,9 @@ steps: - label: LoRA Test %N # 20min each timeout_in_minutes: 30 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 + optional: true # grade: Blocking source_file_dependencies: - vllm/lora @@ -664,8 +702,9 @@ steps: - label: Kernels Quantization Test %N # 64min timeout_in_minutes: 90 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 + optional: true # grade: Blocking source_file_dependencies: - csrc/quantization/ @@ -798,8 +837,9 @@ steps: - label: LM Eval Small Models # 53min timeout_in_minutes: 75 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 + optional: true # grade: Blocking source_file_dependencies: - csrc/ @@ -860,8 +900,9 @@ steps: - label: Basic Models Tests (Other) timeout_in_minutes: 45 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 + optional: true # grade: Blocking torch_nightly: true source_file_dependencies: @@ -902,8 +943,9 @@ steps: - label: Language Models Tests (Extra Standard) %N timeout_in_minutes: 45 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 + optional: true # grade: Blocking torch_nightly: true source_file_dependencies: @@ -923,8 +965,9 @@ steps: - label: Language Models Tests (Hybrid) %N timeout_in_minutes: 75 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 + optional: true # grade: Blocking torch_nightly: true source_file_dependencies: @@ -944,7 +987,7 @@ steps: - label: Language Models Test (Extended Generation) # 80min timeout_in_minutes: 110 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking optional: true @@ -960,7 +1003,7 @@ steps: - label: Language Models Test (PPL) timeout_in_minutes: 110 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking optional: true @@ -972,7 +1015,7 @@ steps: - label: Language Models Test (Extended Pooling) # 36min timeout_in_minutes: 50 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking optional: true @@ -984,7 +1027,7 @@ steps: - label: Language Models Test (MTEB) timeout_in_minutes: 110 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking optional: true @@ -996,7 +1039,7 @@ steps: - label: Multi-Modal Processor Test (CPU) timeout_in_minutes: 60 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 source_file_dependencies: - vllm/ @@ -1008,7 +1051,7 @@ steps: - label: Multi-Modal Processor Test # 44min timeout_in_minutes: 60 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking source_file_dependencies: @@ -1020,7 +1063,7 @@ steps: - label: Multi-Modal Models Test (Standard) # 60min timeout_in_minutes: 100 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking torch_nightly: true @@ -1053,7 +1096,7 @@ steps: - label: Multi-Modal Models Test (Extended) 1 # 60min timeout_in_minutes: 120 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking optional: true @@ -1068,7 +1111,7 @@ steps: - label: Multi-Modal Models Test (Extended) 2 #60min timeout_in_minutes: 120 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking optional: true @@ -1083,7 +1126,7 @@ steps: - label: Multi-Modal Models Test (Extended) 3 # 75min timeout_in_minutes: 150 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking optional: true @@ -1108,7 +1151,7 @@ steps: - pytest -v -s models/quantization - label: Transformers Nightly Models Test - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_1 # grade: Blocking working_dir: "/vllm-workspace/" @@ -1263,8 +1306,9 @@ steps: - label: 2 Node Tests (4 GPUs in total) # 16min timeout_in_minutes: 30 - mirror_hardwares: [amdexperimental, amdmultinode] + mirror_hardwares: [amdexperimental, amdproduction, amdmultinode] agent_pool: mi325_4 + optional: true # grade: Blocking working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -1290,8 +1334,9 @@ steps: - label: Distributed Tests (2 GPUs) # 68min timeout_in_minutes: 90 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_2 + optional: true # grade: Blocking working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -1330,8 +1375,9 @@ steps: - label: Distributed Model Tests (2 GPUs) # 37min timeout_in_minutes: 50 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_2 + optional: true # grade: Blocking working_dir: "/vllm-workspace/tests" num_gpus: 2 @@ -1370,6 +1416,10 @@ steps: - pip install -e ./plugins/prithvi_io_processor_plugin - pytest -v -s plugins_tests/test_io_processor_plugins.py - pip uninstall prithvi_io_processor_plugin -y + # test bge_m3_sparse io_processor plugin + - pip install -e ./plugins/bge_m3_sparse_plugin + - pytest -v -s plugins_tests/test_bge_m3_sparse_io_processor_plugins.py + - pip uninstall bge_m3_sparse_plugin -y # end io_processor plugins test # begin stat_logger plugins test - pip install -e ./plugins/vllm_add_dummy_stat_logger @@ -1441,7 +1491,7 @@ steps: - bash weight_loading/run_model_weight_loading_test.sh -c weight_loading/models-amd.txt - label: Weight Loading Multiple GPU Test - Large Models # optional - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_2 # grade: Blocking working_dir: "/vllm-workspace/tests" @@ -1485,7 +1535,7 @@ steps: ##### A100 test ##### - label: Distributed Tests (A100) # optional - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_4 # grade: Blocking gpu: a100 @@ -1508,7 +1558,7 @@ steps: - label: LM Eval Large Models # optional gpu: a100 optional: true - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_4 # grade: Blocking num_gpus: 4 @@ -1520,11 +1570,11 @@ steps: - export VLLM_WORKER_MULTIPROC_METHOD=spawn - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 -##### H100 test ##### -- label: LM Eval Large Models (H100) # optional +##### FP8 test ##### +- label: LM Eval Large Models (H100) # optional, still use H100 for consistency gpu: h100 optional: true - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_4 # grade: Blocking num_gpus: 4 @@ -1533,13 +1583,13 @@ steps: - csrc/ - vllm/model_executor/layers/quantization commands: - - export VLLM_USE_DEEP_GEMM=0 # We found Triton is faster than DeepGEMM for H100 - - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-hopper.txt --tp-size=4 + - export VLLM_USE_DEEP_GEMM=0 + - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large-rocm.txt --tp-size=4 ##### H200 test ##### - label: Distributed Tests (H200) # optional - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_2 # grade: Blocking gpu: h200 @@ -1599,8 +1649,9 @@ steps: - pytest -s -v test_lm_eval_correctness.py --config-list-file=configs/models-large.txt --tp-size=4 - label: ROCm LM Eval Large Models (8 Card) - mirror_hardwares: [amdproduction] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_8 + optional: true num_gpus: 8 working_dir: "/vllm-workspace/.buildkite/lm-eval-harness" commands: @@ -1659,7 +1710,7 @@ steps: - label: Qwen3-Next-80B-A3B-Instruct MTP Async EPLB Accuracy timeout_in_minutes: 60 - mirror_hardwares: [amdexperimental] + mirror_hardwares: [amdexperimental, amdproduction] agent_pool: mi325_4 # grade: Blocking optional: true @@ -2946,6 +2997,10 @@ steps: - pip install -e ./plugins/prithvi_io_processor_plugin - pytest -v -s plugins_tests/test_io_processor_plugins.py - pip uninstall prithvi_io_processor_plugin -y + # test bge_m3_sparse io_processor plugin + - pip install -e ./plugins/bge_m3_sparse_plugin + - pytest -v -s plugins_tests/test_bge_m3_sparse_io_processor_plugins.py + - pip uninstall bge_m3_sparse_plugin -y # end io_processor plugins test # begin stat_logger plugins test - pip install -e ./plugins/vllm_add_dummy_stat_logger @@ -3227,4 +3282,4 @@ steps: num_gpus: 4 working_dir: "/vllm-workspace" commands: - - bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040 \ No newline at end of file + - bash .buildkite/scripts/scheduled_integration_test/qwen3_next_mtp_async_eplb.sh 0.8 1319 8040 diff --git a/.buildkite/test_areas/distributed.yaml b/.buildkite/test_areas/distributed.yaml index 9b5b002f4b70..0a75bc50e484 100644 --- a/.buildkite/test_areas/distributed.yaml +++ b/.buildkite/test_areas/distributed.yaml @@ -103,7 +103,8 @@ steps: - VLLM_ALLOW_INSECURE_SERIALIZATION=1 RAY_DEDUP_LOGS=0 python3 rlhf_colocate.py # NEW rlhf examples - cd new_weight_syncing - - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_nccl.py + - VLLM_ALLOW_INSECURE_SERIALIZATION=1 python3 rlhf_ipc.py - label: Distributed Tests (8 GPUs)(H100) timeout_in_minutes: 10 diff --git a/.buildkite/test_areas/engine.yaml b/.buildkite/test_areas/engine.yaml index 19cd91370e64..b5b3eeb6d728 100644 --- a/.buildkite/test_areas/engine.yaml +++ b/.buildkite/test_areas/engine.yaml @@ -14,7 +14,7 @@ steps: commands: - pytest -v -s engine test_sequence.py test_config.py test_logger.py test_vllm_port.py -- label: V1 e2e + engine +- label: V1 e2e + engine (1 GPU) timeout_in_minutes: 45 source_file_dependencies: - vllm/ @@ -36,3 +36,35 @@ steps: commands: - pytest -v -s v1/e2e - pytest -v -s v1/engine + +- label: V1 e2e (2 GPUs) + timeout_in_minutes: 60 # TODO: Fix timeout after we have more confidence in the test stability + optional: true + num_devices: 2 + source_file_dependencies: + - vllm/ + - tests/v1/e2e + commands: + # Only run tests that need exactly 2 GPUs + - pytest -v -s v1/e2e/test_spec_decode.py -k "tensor_parallelism" + mirror: + amd: + device: mi325_2 + depends_on: + - image-build-amd + +- label: V1 e2e (4 GPUs) + timeout_in_minutes: 60 # TODO: Fix timeout after we have more confidence in the test stability + optional: true + num_devices: 4 + source_file_dependencies: + - vllm/ + - tests/v1/e2e + commands: + # Only run tests that need 4 GPUs + - pytest -v -s v1/e2e/test_spec_decode.py -k "eagle_correctness_heavy" + mirror: + amd: + device: mi325_4 + depends_on: + - image-build-amd diff --git a/.buildkite/test_areas/expert_parallelism.yaml b/.buildkite/test_areas/expert_parallelism.yaml index 9a10476ed78a..1443d847eaf5 100644 --- a/.buildkite/test_areas/expert_parallelism.yaml +++ b/.buildkite/test_areas/expert_parallelism.yaml @@ -20,4 +20,19 @@ steps: - tests/distributed/test_eplb_execute.py commands: - pytest -v -s distributed/test_eplb_execute.py - - pytest -v -s distributed/test_eplb_spec_decode.py \ No newline at end of file + - pytest -v -s distributed/test_eplb_spec_decode.py + +- label: Elastic EP Scaling Test + timeout_in_minutes: 20 + device: b200 + optional: true + working_dir: "/vllm-workspace/tests" + num_devices: 4 + source_file_dependencies: + - vllm/distributed/ + - vllm/engine/ + - vllm/executor/ + - vllm/compilation/ + - tests/distributed/ + commands: + - pytest -v -s distributed/test_elastic_ep.py diff --git a/.buildkite/test_areas/kernels.yaml b/.buildkite/test_areas/kernels.yaml index afc8fc49a2aa..e1ecfeb8415f 100644 --- a/.buildkite/test_areas/kernels.yaml +++ b/.buildkite/test_areas/kernels.yaml @@ -70,7 +70,7 @@ steps: - tests/kernels/moe/test_batched_deepgemm.py - tests/kernels/attention/test_deepgemm_attention.py commands: - - pytest -v -s kernels/quantization/test_block_fp8.py -k deep_gemm + - pytest -v -s kernels/quantization/test_block_fp8.py - pytest -v -s kernels/moe/test_deepgemm.py - pytest -v -s kernels/moe/test_batched_deepgemm.py - pytest -v -s kernels/attention/test_deepgemm_attention.py @@ -155,5 +155,14 @@ steps: commands: - pytest -v -s kernels/moe/test_deepep_deepgemm_moe.py - pytest -v -s kernels/moe/test_deepep_moe.py - - pytest -v -s kernels/moe/test_pplx_cutlass_moe.py - # - pytest -v -s kernels/moe/test_pplx_moe.py - failing on main + +- label: Kernels Fp4 MoE Test (B200) + timeout_in_minutes: 60 + device: b200 + num_devices: 1 + optional: true + commands: + - pytest -v -s kernels/moe/test_cutedsl_moe.py + - pytest -v -s kernels/moe/test_flashinfer_moe.py + - pytest -v -s kernels/moe/test_nvfp4_moe.py + - pytest -v -s kernels/moe/test_ocp_mx_moe.py diff --git a/.buildkite/test_areas/misc.yaml b/.buildkite/test_areas/misc.yaml index 69390cd6d373..d8957c217755 100644 --- a/.buildkite/test_areas/misc.yaml +++ b/.buildkite/test_areas/misc.yaml @@ -9,6 +9,7 @@ steps: - tests/v1 commands: - uv pip install --system -r /vllm-workspace/requirements/kv_connectors.txt + - export VLLM_WORKER_MULTIPROC_METHOD=spawn # split the test to avoid interference - pytest -v -s -m 'not cpu_test' v1/core - pytest -v -s v1/executor diff --git a/.buildkite/test_areas/plugins.yaml b/.buildkite/test_areas/plugins.yaml index ccc54b47abd4..16f9abccf6e1 100644 --- a/.buildkite/test_areas/plugins.yaml +++ b/.buildkite/test_areas/plugins.yaml @@ -19,6 +19,10 @@ steps: - pip install -e ./plugins/prithvi_io_processor_plugin - pytest -v -s plugins_tests/test_io_processor_plugins.py - pip uninstall prithvi_io_processor_plugin -y + # test bge_m3_sparse io_processor plugin + - pip install -e ./plugins/bge_m3_sparse_plugin + - pytest -v -s plugins_tests/test_bge_m3_sparse_io_processor_plugins.py + - pip uninstall bge_m3_sparse_plugin -y # end io_processor plugins test # begin stat_logger plugins test - pip install -e ./plugins/vllm_add_dummy_stat_logger diff --git a/.github/.bc-linter.yml b/.github/.bc-linter.yml deleted file mode 100644 index 443dfa45af22..000000000000 --- a/.github/.bc-linter.yml +++ /dev/null @@ -1,24 +0,0 @@ -# doc: https://github.com/pytorch/test-infra/blob/main/tools/stronghold/docs/bc_linter_config.md -version: 1 -paths: -# We temporarily disable globally, and will only enable with `annotations.include` -# include: -# - "vllm/v1/attetion/*.py" -# - "vllm/v1/core/*.py" -exclude: - - "**/*.py" - -scan: - functions: true # check free functions and methods - classes: true # check classes/dataclasses - public_only: true # ignore names starting with "_" at any level - -annotations: - include: # decorators that force‑include a symbol - - name: "bc_linter_include" # matched by simple name or dotted suffix - propagate_to_members: false # for classes, include methods/inner classes - exclude: # decorators that force‑exclude a symbol - - name: "bc_linter_skip" # matched by simple name or dotted suffix - propagate_to_members: true # for classes, exclude methods/inner classes - -excluded_violations: [] # e.g. ["ParameterRenamed", "FieldTypeChanged"] diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index adf50a185e55..653d6c42e9af 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,7 +2,7 @@ # for more info about CODEOWNERS file # This lists cover the "core" components of vLLM that require careful review -/vllm/compilation @zou3519 @youkaichao @ProExpertProg +/vllm/compilation @zou3519 @youkaichao @ProExpertProg @BoyuanFeng /vllm/distributed/kv_transfer @NickLucche @ApostaC @orozery /vllm/lora @jeejeelee /vllm/model_executor/layers/attention @LucasWilkinson @MatthewBonanni @@ -54,11 +54,14 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /vllm/v1/structured_output @mgoin @russellb @aarnphm @benchislett /vllm/v1/kv_cache_interface.py @heheda12345 /vllm/v1/kv_offload @ApostaC @orozery -/vllm/v1/worker/gpu/kv_connector.py @orozery +/vllm/v1/engine @njhill +/vllm/v1/executor @njhill +/vllm/v1/worker @njhill /vllm/v1/worker/kv_connector_model_runner_mixin.py @orozery @NickLucche # Model runner V2 -/vllm/v1/worker/gpu @WoosukKwon +/vllm/v1/worker/gpu @WoosukKwon @njhill +/vllm/v1/worker/gpu/kv_connector.py @orozery # Test ownership /.buildkite/lm-eval-harness @mgoin diff --git a/.github/mergify.yml b/.github/mergify.yml index 080767ca7218..9c53342d1737 100644 --- a/.github/mergify.yml +++ b/.github/mergify.yml @@ -259,8 +259,7 @@ pull_request_rules: - files=benchmarks/run_structured_output_benchmark.sh - files=docs/features/structured_outputs.md - files=examples/offline_inference/structured_outputs.py - - files=examples/online_serving/openai_chat_completion_structured_outputs.py - - files=examples/online_serving/openai_chat_completion_structured_outputs_with_reasoning.py + - files=examples/online_serving/structured_outputs/structured_outputs.py - files~=^tests/v1/structured_output/ - files=tests/v1/entrypoints/llm/test_struct_output_generate.py - files~=^vllm/v1/structured_output/ diff --git a/.github/workflows/bc-lint.yml b/.github/workflows/bc-lint.yml deleted file mode 100644 index 823695a92132..000000000000 --- a/.github/workflows/bc-lint.yml +++ /dev/null @@ -1,29 +0,0 @@ -name: BC Lint - -on: - pull_request: - types: - - opened - - synchronize - - reopened - - labeled - - unlabeled - -jobs: - bc_lint: - if: github.repository_owner == 'vllm-project' - runs-on: ubuntu-latest - steps: - - name: Run BC Lint Action - uses: pytorch/test-infra/.github/actions/bc-lint@main - with: - repo: ${{ github.event.pull_request.head.repo.full_name }} - base_sha: ${{ github.event.pull_request.base.sha }} - head_sha: ${{ github.event.pull_request.head.sha }} - suppression: ${{ contains(github.event.pull_request.labels.*.name, 'suppress-bc-linter') }} - docs_link: 'https://github.com/pytorch/test-infra/wiki/BC-Linter' - config_dir: .github - -concurrency: - group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.sha }} - cancel-in-progress: true diff --git a/.gitignore b/.gitignore index 8e864d090c9d..795071bd77f7 100644 --- a/.gitignore +++ b/.gitignore @@ -3,6 +3,8 @@ # vllm-flash-attn built from source vllm/vllm_flash_attn/* +!vllm/vllm_flash_attn/__init__.py +!vllm/vllm_flash_attn/flash_attn_interface.py # OpenAI triton kernels copied from source vllm/third_party/triton_kernels/* diff --git a/CMakeLists.txt b/CMakeLists.txt index 55127a514f1f..65df275cd314 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -725,7 +725,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # CUTLASS MoE kernels # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works - # on Hopper). get_cutlass_(pplx_)moe_mm_data should only be compiled + # on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled # if it's possible to compile MoE kernels that use its output. cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) @@ -771,6 +771,33 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") endif() endif() + # Expert-specialization MXFP8 blockscaled grouped kernels (SM100+). + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND ES_MXFP8_GROUPED_MM_ARCHS) + set(SRCS + "csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu" + "csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${ES_MXFP8_GROUPED_MM_ARCHS}") + list(APPEND VLLM_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_ES_MXFP8_GROUPED_MM_SM100=1") + message(STATUS "Building ES MXFP8 grouped kernels for archs: ${ES_MXFP8_GROUPED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 + AND ES_MXFP8_GROUPED_MM_ARCHS) + message(STATUS "Not building ES MXFP8 grouped kernels as CUDA Compiler version is " + "not >= 12.8.") + else() + message(STATUS "Not building ES MXFP8 grouped kernels as no compatible archs found " + "in CUDA target architectures.") + endif() + endif() + # DeepSeek V3 fused A GEMM kernel (requires SM 9.0+, Hopper and later) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) cuda_archs_loose_intersection(DSV3_FUSED_A_GEMM_ARCHS "9.0a;10.0f;11.0f" "${CUDA_ARCHS}") @@ -971,7 +998,8 @@ set(VLLM_MOE_EXT_SRC if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_MOE_EXT_SRC "csrc/moe/moe_wna16.cu" - "csrc/moe/grouped_topk_kernels.cu") + "csrc/moe/grouped_topk_kernels.cu" + "csrc/moe/router_gemm.cu") endif() if(VLLM_GPU_LANG STREQUAL "CUDA") diff --git a/benchmarks/attention_benchmarks/__init__.py b/benchmarks/attention_benchmarks/__init__.py index df7a6328569d..2d21288700a5 100644 --- a/benchmarks/attention_benchmarks/__init__.py +++ b/benchmarks/attention_benchmarks/__init__.py @@ -15,7 +15,6 @@ BenchmarkConfig, BenchmarkResult, MockLayer, - MockModelConfig, ResultsFormatter, get_attention_scale, is_mla_backend, @@ -36,7 +35,6 @@ "ResultsFormatter", # Mock objects "MockLayer", - "MockModelConfig", # Utilities "setup_mla_dims", "get_attention_scale", diff --git a/benchmarks/attention_benchmarks/common.py b/benchmarks/attention_benchmarks/common.py index 1de8bb0a55b7..6bba93e50238 100644 --- a/benchmarks/attention_benchmarks/common.py +++ b/benchmarks/attention_benchmarks/common.py @@ -10,7 +10,6 @@ from pathlib import Path from typing import Any -import numpy as np import torch from batch_spec import get_batch_type, parse_batch_spec from rich.console import Console @@ -62,10 +61,7 @@ def get_text_config(self): # Import AttentionLayerBase at module level to avoid circular dependencies try: from vllm.model_executor.layers.attention_layer_base import AttentionLayerBase - - _HAS_ATTENTION_LAYER_BASE = True except ImportError: - _HAS_ATTENTION_LAYER_BASE = False AttentionLayerBase = object # Fallback @@ -167,95 +163,6 @@ def get_kv_cache_spec(self): return self._kv_cache_spec -class MockModelConfig: - """Mock model configuration.""" - - def __init__( - self, - num_q_heads: int, - num_kv_heads: int, - head_dim: int, - dtype: torch.dtype = torch.float16, - max_model_len: int = 32768, - ): - self._n_q = num_q_heads - self._n_kv = num_kv_heads - self._d = head_dim - self.dtype = dtype - self.max_model_len = max_model_len - - def get_num_attention_heads(self, _=None) -> int: - return self._n_q - - def get_num_kv_heads(self, _=None) -> int: - return self._n_kv - - def get_head_size(self) -> int: - return self._d - - def get_num_layers(self) -> int: - """Mock method for layer count queries.""" - return 1 - - def get_sliding_window_for_layer(self, _layer_idx: int): - """Mock method for sliding window queries.""" - return None - - def get_logits_soft_cap_for_layer(self, _layer_idx: int): - """Mock method for logits soft cap queries.""" - return None - - def get_sm_scale_for_layer(self, _layer_idx: int) -> float: - """Mock method for SM scale queries.""" - return 1.0 / (self.get_head_size() ** 0.5) - - -class MockParallelConfig: - """Mock parallel configuration.""" - - pass - - -class MockCompilationConfig: - """Mock compilation configuration.""" - - def __init__(self): - self.full_cuda_graph = False - self.static_forward_context = {} - - -class MockVLLMConfig: - """Mock VLLM configuration.""" - - def __init__(self): - self.compilation_config = MockCompilationConfig() - - -class MockRunner: - """Mock GPU runner for metadata builders.""" - - def __init__( - self, - seq_lens: np.ndarray, - query_start_locs: np.ndarray, - device: torch.device, - num_q_heads: int, - num_kv_heads: int, - head_dim: int, - dtype: torch.dtype, - ): - self.model_config = MockModelConfig(num_q_heads, num_kv_heads, head_dim, dtype) - self.parallel_config = MockParallelConfig() - self.vllm_config = MockVLLMConfig() - self.seq_lens_np = seq_lens - self.query_start_loc_np = query_start_locs - self.device = device - self.attention_chunk_size = None - self.num_query_heads = num_q_heads - self.num_kv_heads = num_kv_heads - self.dtype = dtype - - @dataclass class ParameterSweep: """Configuration for sweeping a backend parameter.""" diff --git a/benchmarks/backend_request_func.py b/benchmarks/backend_request_func.py index 831b76b66e09..a69637bfc437 100644 --- a/benchmarks/backend_request_func.py +++ b/benchmarks/backend_request_func.py @@ -649,9 +649,3 @@ def get_tokenizer( "sglang": async_request_openai_completions, "llama.cpp": async_request_openai_completions, } - -OPENAI_COMPATIBLE_BACKENDS = [ - k - for k, v in ASYNC_REQUEST_FUNCS.items() - if v in (async_request_openai_completions, async_request_openai_chat_completions) -] diff --git a/benchmarks/benchmark_utils.py b/benchmarks/benchmark_utils.py index f0d661f9d534..5865473e9542 100644 --- a/benchmarks/benchmark_utils.py +++ b/benchmarks/benchmark_utils.py @@ -1,78 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import argparse -import json -import math -import os import time from types import TracebackType -from typing import Any - - -def convert_to_pytorch_benchmark_format( - args: argparse.Namespace, metrics: dict[str, list], extra_info: dict[str, Any] -) -> list: - """ - Save the benchmark results in the format used by PyTorch OSS benchmark with - on metric per record - https://github.com/pytorch/pytorch/wiki/How-to-integrate-with-PyTorch-OSS-benchmark-database - """ - records = [] - if not os.environ.get("SAVE_TO_PYTORCH_BENCHMARK_FORMAT", False): - return records - - for name, benchmark_values in metrics.items(): - record = { - "benchmark": { - "name": "vLLM benchmark", - "extra_info": { - "args": vars(args), - }, - }, - "model": { - "name": args.model, - }, - "metric": { - "name": name, - "benchmark_values": benchmark_values, - "extra_info": extra_info, - }, - } - - tp = record["benchmark"]["extra_info"]["args"].get("tensor_parallel_size") - # Save tensor_parallel_size parameter if it's part of the metadata - if not tp and "tensor_parallel_size" in extra_info: - record["benchmark"]["extra_info"]["args"]["tensor_parallel_size"] = ( - extra_info["tensor_parallel_size"] - ) - - records.append(record) - - return records - - -class InfEncoder(json.JSONEncoder): - def clear_inf(self, o: Any): - if isinstance(o, dict): - return {k: self.clear_inf(v) for k, v in o.items()} - elif isinstance(o, list): - return [self.clear_inf(v) for v in o] - elif isinstance(o, float) and math.isinf(o): - return "inf" - return o - - def iterencode(self, o: Any, *args, **kwargs) -> Any: - return super().iterencode(self.clear_inf(o), *args, **kwargs) - - -def write_to_json(filename: str, records: list) -> None: - with open(filename, "w") as f: - json.dump( - records, - f, - cls=InfEncoder, - default=lambda o: f"<{type(o).__name__} object is not JSON serializable>", - ) # Collect time and generate time metrics diff --git a/benchmarks/cutlass_benchmarks/utils.py b/benchmarks/cutlass_benchmarks/utils.py index b4f3c6bf94ed..6cbcf6b68c89 100644 --- a/benchmarks/cutlass_benchmarks/utils.py +++ b/benchmarks/cutlass_benchmarks/utils.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Cutlass bench utils -from collections.abc import Iterable import torch @@ -86,15 +85,3 @@ def make_rand_sparse_tensors( # Compressed B, Metadata, Original A, B return b_compressed, e, a, b - - -def make_n_rand_sparse_tensors( - num_tensors: int, dtype: torch.dtype, m: int, n: int, k: int -) -> tuple[Iterable[torch.Tensor], Iterable[torch.Tensor]]: - ABs = [] - for _ in range(num_tensors): - b_comp, e, a, b = make_rand_sparse_tensors(dtype, m, n, k) - if b_comp is not None: - ABs.append(make_rand_sparse_tensors(dtype, m, n, k)) - BComps, Es, As, Bs = zip(*ABs) - return list(BComps), list(Es), list(As), list(Bs) diff --git a/benchmarks/disagg_benchmarks/rate_limiter.py b/benchmarks/disagg_benchmarks/rate_limiter.py deleted file mode 100644 index 87ac8cb6ab1a..000000000000 --- a/benchmarks/disagg_benchmarks/rate_limiter.py +++ /dev/null @@ -1,45 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -import time - - -class RateLimiter: - """Token bucket rate limiter implementation""" - - def __init__(self, rate_limit): - self.rate_limit = rate_limit # Requests per second - self.num_available_tokens = rate_limit # Available tokens - self.last_refill = time.monotonic() # Last token refill time - self.lock = asyncio.Lock() # Synchronization lock - - async def acquire(self): - """Acquire a token from the rate limiter""" - while True: - async with self.lock: - current_time = time.monotonic() - elapsed = current_time - self.last_refill - - # Refill num_available_tokens if more than 1 second has passed - if elapsed > 1.0: - self.num_available_tokens = self.rate_limit - self.last_refill = current_time - - # Check if num_available_tokens are available - if self.num_available_tokens > 0: - self.num_available_tokens -= 1 - return True - - # Calculate wait time if no num_available_tokens available - wait_time = 1.0 - elapsed - await asyncio.sleep(wait_time) - - async def __aenter__(self): - """Enter async context manager - acquire token""" - await self.acquire() - return self - - async def __aexit__(self, exc_type, exc_value, traceback): - """Exit async context manager - no cleanup needed""" - pass diff --git a/benchmarks/disagg_benchmarks/request_queue.py b/benchmarks/disagg_benchmarks/request_queue.py deleted file mode 100644 index 410bcb956050..000000000000 --- a/benchmarks/disagg_benchmarks/request_queue.py +++ /dev/null @@ -1,39 +0,0 @@ -# SPDX-License-Identifier: Apache-2.0 -# SPDX-FileCopyrightText: Copyright contributors to the vLLM project - -import asyncio -from collections import deque - - -class RequestQueue: - """Request queue manager with concurrency control""" - - def __init__(self, max_concurrent, max_queue_size): - # Maximum concurrent requests - self.max_concurrent = max_concurrent - self.max_queue_size = max_queue_size # Maximum queue size - # Concurrency control - self.semaphore = asyncio.Semaphore(max_concurrent) - self.queue = deque() # Request queue - self.queue_size = 0 # Current queue size - self.lock = asyncio.Lock() # Sync queue Lock - - async def enqueue(self, task): - """Add a request task to the queue""" - async with self.lock: - if self.queue_size >= self.max_queue_size: - return False - - self.queue.append(task) - self.queue_size += 1 - return True - - async def process(self): - """Process queued requests using semaphore for concurrency control""" - while True: - if self.queue: - async with self.semaphore, self.lock: - task = self.queue.popleft() - self.queue_size -= 1 - await task - await asyncio.sleep(0.01) # Yield control to event loop diff --git a/benchmarks/kernels/benchmark_device_communicators.py b/benchmarks/kernels/benchmark_device_communicators.py index 7b453fe7b680..d1005461ab93 100644 --- a/benchmarks/kernels/benchmark_device_communicators.py +++ b/benchmarks/kernels/benchmark_device_communicators.py @@ -30,6 +30,9 @@ from torch.distributed import ProcessGroup from vllm.distributed.device_communicators.custom_all_reduce import CustomAllreduce +from vllm.distributed.device_communicators.flashinfer_all_reduce import ( + FlashInferAllReduce, +) from vllm.distributed.device_communicators.pynccl import ( PyNcclCommunicator, register_nccl_symmetric_ops, @@ -44,7 +47,7 @@ logger = init_logger(__name__) # Default sequence lengths to benchmark -DEFAULT_SEQUENCE_LENGTHS = [128, 512, 1024, 2048, 4096, 8192] +DEFAULT_SEQUENCE_LENGTHS = [16, 64, 128, 512, 1024, 2048, 4096, 8192] # Fixed hidden size and dtype for all benchmarks HIDDEN_SIZE = 8192 @@ -81,6 +84,7 @@ def __init__( self.symm_mem_comm = None self.symm_mem_comm_multimem = None self.symm_mem_comm_two_shot = None + self.fi_ar_comm = None self._init_communicators() @@ -161,6 +165,22 @@ def _init_communicators(self): ) self.symm_mem_comm_two_shot = None + try: + self.fi_ar_comm = FlashInferAllReduce( + group=self.cpu_group, + device=self.device, + ) + if not self.fi_ar_comm.disabled: + logger.info("Rank %s: FlashInferAllReduce initialized", self.rank) + else: + logger.info("Rank %s: FlashInferAllReduce disabled", self.rank) + self.fi_ar_comm = None + except Exception as e: + logger.warning( + "Rank %s: Failed to initialize FlashInferAllReduce: %s", self.rank, e + ) + self.fi_ar_comm = None + def benchmark_allreduce( self, sequence_length: int, num_warmup: int, num_trials: int ) -> dict[str, float]: @@ -180,7 +200,8 @@ def benchmark_allreduce( lambda t, c=comm: c.custom_all_reduce(t), lambda t, c=comm: c.should_custom_ar(t), comm.capture(), - "1stage", # env variable value + {"VLLM_CUSTOM_ALLREDUCE_ALGO": "1stage"}, + None, # no destroy function ) ) # CustomAllreduce two-shot @@ -190,7 +211,8 @@ def benchmark_allreduce( lambda t, c=comm: c.custom_all_reduce(t), lambda t, c=comm: c.should_custom_ar(t), comm.capture(), - "2stage", # env variable value + {"VLLM_CUSTOM_ALLREDUCE_ALGO": "2stage"}, + None, # no destroy function ) ) @@ -202,7 +224,8 @@ def benchmark_allreduce( lambda t, c=comm: c.all_reduce(t), lambda t: True, # Always available if initialized nullcontext(), - None, # no env variable needed + {}, # no env variable needed + None, # no destroy function ) ) communicators.append( @@ -211,7 +234,8 @@ def benchmark_allreduce( lambda t: torch.ops.vllm.all_reduce_symmetric_with_copy(t), lambda t: True, # Always available if initialized nullcontext(), - None, # no env variable needed + {}, # no env variable needed + None, # no destroy function ) ) @@ -223,7 +247,8 @@ def benchmark_allreduce( lambda t, c=comm: c.all_reduce(t), lambda t, c=comm: c.should_use_symm_mem(t), nullcontext(), - None, # no env variable needed + {}, # no env variable needed + None, # no destroy function ) ) @@ -235,29 +260,67 @@ def benchmark_allreduce( lambda t, c=comm: c.all_reduce(t), lambda t, c=comm: c.should_use_symm_mem(t), nullcontext(), - None, # no env variable needed + {}, # no env variable needed + None, # no destroy function needed ) ) - # Benchmark each communicator - for name, allreduce_fn, should_use_fn, context, env_var in communicators: - # Set environment variable if needed - if env_var is not None: - os.environ["VLLM_CUSTOM_ALLREDUCE_ALGO"] = env_var - else: - # Clear the environment variable to avoid interference - os.environ.pop("VLLM_CUSTOM_ALLREDUCE_ALGO", None) - - latency = self.benchmark_allreduce_single( - sequence_length, - allreduce_fn, - should_use_fn, - context, - num_warmup, - num_trials, + if self.fi_ar_comm is not None: + comm = self.fi_ar_comm + communicators.append( + ( + "flashinfer_trtllm", + lambda t, c=comm: c.all_reduce(t), + lambda t, c=comm: c.should_use_fi_ar(t), + nullcontext(), + {"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "trtllm"}, + lambda c=comm: c.destroy(), + ) ) - if latency is not None: - results[name] = latency + communicators.append( + ( + "flashinfer_mnnvl", + lambda t, c=comm: c.all_reduce(t), + lambda t, c=comm: c.should_use_fi_ar(t), + nullcontext(), + {"VLLM_FLASHINFER_ALLREDUCE_BACKEND": "mnnvl"}, + lambda c=comm: c.destroy(), + ) + ) + + # Benchmark each communicator + for ( + name, + allreduce_fn, + should_use_fn, + context, + env_dict, + destroy_fn, + ) in communicators: + # Save original values and apply new environment variables + saved_env = {key: os.environ.get(key) for key in env_dict} + for key, value in env_dict.items(): + os.environ[key] = value + try: + latency = self.benchmark_allreduce_single( + sequence_length, + allreduce_fn, + should_use_fn, + context, + num_warmup, + num_trials, + ) + if latency is not None: + results[name] = latency + finally: + if destroy_fn is not None: + destroy_fn() + # Restore environment variables to their original state + for key, original_value in saved_env.items(): + if original_value is None: + os.environ.pop(key, None) + else: + os.environ[key] = original_value return results diff --git a/benchmarks/kernels/benchmark_fused_collective.py b/benchmarks/kernels/benchmark_fused_collective.py index 633529edf16d..e18f6a7580fb 100644 --- a/benchmarks/kernels/benchmark_fused_collective.py +++ b/benchmarks/kernels/benchmark_fused_collective.py @@ -5,8 +5,11 @@ Benchmark for FlashInfer fused collective operations vs standard operations. This benchmark compares: -1. FlashInfer's allreduce_fusion (fused allreduce + rmsnorm + optional quant) -2. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations +1. FlashInfer's allreduce_fusion with trtllm backend + (fused allreduce + rmsnorm + optional FP8/FP4 quant) +2. FlashInfer's allreduce_fusion with mnnvl backend + (fused allreduce + rmsnorm only, no quantization support) +3. Standard tensor_model_parallel_all_reduce + separate rmsnorm/quant operations Usage with torchrun: torchrun --nproc_per_node=2 benchmark_fused_collective.py @@ -48,8 +51,12 @@ logger = init_logger(__name__) # Try to import FlashInfer +TorchDistBackend = None try: import flashinfer.comm as flashinfer_comm # type: ignore + from flashinfer.comm.mnnvl import ( # type: ignore + TorchDistBackend, + ) if not ( hasattr(flashinfer_comm, "allreduce_fusion") @@ -74,11 +81,15 @@ 8: 64 * MiB, # 64MB } -# Global workspace tensor for FlashInfer -_FI_WORKSPACE = None +# Global workspace tensors for FlashInfer (keyed by backend name) +_FI_WORKSPACES: dict = {} + +# Backends to benchmark +FLASHINFER_BACKENDS = ["trtllm", "mnnvl"] def setup_flashinfer_workspace( + backend: str, world_size: int, rank: int, hidden_dim: int, @@ -86,41 +97,54 @@ def setup_flashinfer_workspace( dtype: torch.dtype, ): """Setup FlashInfer workspace for fused allreduce operations.""" - global _FI_WORKSPACE + global FI_WORKSPACES if flashinfer_comm is None: - return None, None + return None if world_size not in _FI_MAX_SIZES: logger.warning("FlashInfer not supported for world size %s", world_size) - return None, None + return None try: + kwargs = {} + if TorchDistBackend is not None: + kwargs["comm_backend"] = TorchDistBackend(group=dist.group.WORLD) + workspace = flashinfer_comm.create_allreduce_fusion_workspace( - backend="trtllm", + backend=backend, world_size=world_size, rank=rank, max_token_num=max_token_num, hidden_dim=hidden_dim, dtype=dtype, + **kwargs, ) - _FI_WORKSPACE = workspace + _FI_WORKSPACES[backend] = workspace return workspace except Exception as e: - logger.error("Failed to setup FlashInfer workspace: %s", e) + logger.error( + "Failed to setup FlashInfer workspace (backend=%s): %s", backend, e + ) return None -def cleanup_flashinfer_workspace(workspace): - """Cleanup FlashInfer workspace.""" - if flashinfer_comm is None or workspace is None: +def cleanup_flashinfer_workspaces(): + """Cleanup all FlashInfer workspaces.""" + if flashinfer_comm is None: return - try: - workspace.destroy() - except Exception as e: - logger.error("Failed to cleanup FlashInfer workspace: %s", e) + for backend, workspace in _FI_WORKSPACES.items(): + try: + workspace.destroy() + except Exception as e: + logger.error( + "Failed to cleanup FlashInfer workspace (backend=%s): %s", + backend, + e, + ) + _FI_WORKSPACES.clear() class FlashInferFusedAllReduceParams: @@ -134,7 +158,7 @@ def __init__( self.fp32_acc = True self.max_token_num = max_token_num - def get_trtllm_fused_allreduce_kwargs(self): + def get_flashinfer_fused_allreduce_kwargs(self): return { "launch_with_pdl": self.launch_with_pdl, "fp32_acc": self.fp32_acc, @@ -147,11 +171,12 @@ def flashinfer_fused_allreduce_rmsnorm( rms_gamma: torch.Tensor, rms_eps: float, allreduce_params: "FlashInferFusedAllReduceParams", + workspace: object, use_oneshot: bool, norm_out: torch.Tensor | None = None, ): """FlashInfer fused allreduce + rmsnorm operation.""" - if flashinfer_comm is None or _FI_WORKSPACE is None: + if flashinfer_comm is None or workspace is None: raise RuntimeError("FlashInfer not available or workspace not initialized") if norm_out is None: @@ -160,9 +185,13 @@ def flashinfer_fused_allreduce_rmsnorm( else: residual_out = input_tensor + layout_code = None + if workspace.backend == "trtllm": + layout_code = flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4 + flashinfer_comm.allreduce_fusion( input=input_tensor, - workspace=_FI_WORKSPACE, + workspace=workspace, pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNorm, residual_in=residual, residual_out=residual_out, @@ -171,10 +200,10 @@ def flashinfer_fused_allreduce_rmsnorm( rms_eps=rms_eps, quant_out=None, scale_out=None, - layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, + layout_code=layout_code, scale_factor=None, use_oneshot=use_oneshot, - **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + **allreduce_params.get_flashinfer_fused_allreduce_kwargs(), ) @@ -185,12 +214,16 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant( rms_eps: float, scale_factor: torch.Tensor, allreduce_params: FlashInferFusedAllReduceParams, + workspace: object, use_oneshot: bool = True, norm_out: torch.Tensor | None = None, quant_out: torch.Tensor | None = None, ): - """FlashInfer fused allreduce + rmsnorm + FP8 quantization.""" - if flashinfer_comm is None or _FI_WORKSPACE is None: + """FlashInfer fused allreduce + rmsnorm + FP8 quantization. + + Note: Only supported by the trtllm backend. + """ + if flashinfer_comm is None or workspace is None: raise RuntimeError("FlashInfer not available or workspace not initialized") if norm_out is None: @@ -201,7 +234,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant( flashinfer_comm.allreduce_fusion( input=input_tensor, - workspace=_FI_WORKSPACE, + workspace=workspace, pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP8Quant, residual_in=residual, residual_out=residual_out, @@ -213,7 +246,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp8_quant( layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, scale_factor=scale_factor, use_oneshot=use_oneshot, - **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + **allreduce_params.get_flashinfer_fused_allreduce_kwargs(), ) @@ -224,13 +257,17 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( rms_eps: float, input_global_scale: torch.Tensor, allreduce_params: FlashInferFusedAllReduceParams, + workspace: object, quant_out: torch.Tensor, use_oneshot: bool, output_scale: torch.Tensor, norm_out: torch.Tensor | None = None, ): - """FlashInfer fused allreduce + rmsnorm + FP4 quantization.""" - if flashinfer_comm is None or _FI_WORKSPACE is None: + """FlashInfer fused allreduce + rmsnorm + FP4 quantization. + + Note: Only supported by the trtllm backend. + """ + if flashinfer_comm is None or workspace is None: raise RuntimeError("FlashInfer not available or workspace not initialized") if norm_out is None: @@ -241,7 +278,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( flashinfer_comm.allreduce_fusion( input=input_tensor, - workspace=_FI_WORKSPACE, + workspace=workspace, pattern=flashinfer_comm.AllReduceFusionPattern.kARResidualRMSNormFP4Quant, residual_in=residual, residual_out=residual_out, @@ -253,7 +290,7 @@ def flashinfer_fused_allreduce_rmsnorm_fp4_quant( layout_code=flashinfer_comm.QuantizationSFLayout.SWIZZLED_128x4, scale_factor=input_global_scale, use_oneshot=use_oneshot, - **allreduce_params.get_trtllm_fused_allreduce_kwargs(), + **allreduce_params.get_flashinfer_fused_allreduce_kwargs(), ) @@ -386,13 +423,16 @@ def run_benchmarks( dtype: torch.dtype, use_residual: bool, allreduce_params: FlashInferFusedAllReduceParams | None, + workspaces: dict, quant_modes: set[str], no_oneshot: bool, ): """Run all benchmarks for given configuration. Args: - quant_mode: "none", "fp8_only", "fp4_only", or "all" + allreduce_params: Shared parameters for FlashInfer fused allreduce. + workspaces: Dict mapping backend name ("trtllm", "mnnvl") to workspace. + quant_modes: Set of quantization modes: "none", "fp8", "fp4". """ ( input_tensor, @@ -454,10 +494,11 @@ def run_benchmarks( logger.error("Standard AllReduce+RMSNorm Native Compiled failed: %s", e) results["standard_allreduce_rmsnorm_native_compiled"] = float("inf") - # FlashInfer Fused AllReduce + RMSNorm Oneshot/Twoshot - if flashinfer_comm is not None and allreduce_params is not None: + # FlashInfer Fused AllReduce + RMSNorm (all backends) + for backend, workspace in workspaces.items(): for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" + key = f"flashinfer_{backend}_fused_allreduce_rmsnorm{suffix}" try: time_ms = benchmark_operation( flashinfer_fused_allreduce_rmsnorm, @@ -467,14 +508,17 @@ def run_benchmarks( rms_gamma=rms_gamma, rms_eps=rms_eps, allreduce_params=allreduce_params, + workspace=workspace, use_oneshot=use_oneshot, ) - results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = time_ms + results[key] = time_ms except Exception as e: - logger.error("FlashInfer Fused AllReduce+RMSNorm failed: %s", e) - results[f"flashinfer_fused_allreduce_rmsnorm{suffix}"] = float( - "inf" + logger.error( + "FlashInfer (%s) Fused AllReduce+RMSNorm failed: %s", + backend, + e, ) + results[key] = float("inf") if "fp8" in quant_modes: # Standard AllReduce + RMSNorm + FP8 Quant @@ -540,10 +584,12 @@ def run_benchmarks( "inf" ) - # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant Oneshot - if flashinfer_comm is not None and allreduce_params is not None: + # FlashInfer Fused AllReduce + RMSNorm + FP8 Quant (trtllm only) + if "trtllm" in workspaces: + trtllm_ws = workspaces["trtllm"] for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" + key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp8_quant{suffix}" try: time_ms = benchmark_operation( flashinfer_fused_allreduce_rmsnorm_fp8_quant, @@ -555,19 +601,16 @@ def run_benchmarks( scale_factor=scale_fp8, quant_out=quant_out_fp8, allreduce_params=allreduce_params, + workspace=trtllm_ws, use_oneshot=use_oneshot, ) - results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = ( - time_ms - ) + results[key] = time_ms except Exception as e: logger.error( - "FlashInfer Fused AllReduce+RMSNorm+FP8 Oneshot failed: %s", + "FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP8 failed: %s", e, ) - results[f"flashinfer_fused_allreduce_rmsnorm_fp8_quant{suffix}"] = ( - float("inf") - ) + results[key] = float("inf") if "fp4" in quant_modes and current_platform.has_device_capability(100): # Standard AllReduce + RMSNorm + FP4 Quant @@ -627,10 +670,12 @@ def run_benchmarks( "inf" ) - # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Oneshot - if flashinfer_comm is not None and allreduce_params is not None: + # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant (trtllm only) + if "trtllm" in workspaces: + trtllm_ws = workspaces["trtllm"] for use_oneshot in use_oneshot_options: suffix = "_oneshot" if use_oneshot else "_twoshot" + key = f"flashinfer_trtllm_fused_allreduce_rmsnorm_fp4_quant{suffix}" try: time_ms = benchmark_operation( flashinfer_fused_allreduce_rmsnorm_fp4_quant, @@ -641,49 +686,18 @@ def run_benchmarks( rms_eps=rms_eps, input_global_scale=scale_fp4, allreduce_params=allreduce_params, + workspace=trtllm_ws, quant_out=fp4_quant_out, output_scale=fp4_output_scale, use_oneshot=use_oneshot, ) - results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = ( - time_ms - ) + results[key] = time_ms except Exception as e: logger.error( - "FlashInfer Fused AllReduce+RMSNorm+FP4 Oneshot failed: %s", + "FlashInfer (trtllm) Fused AllReduce+RMSNorm+FP4 failed: %s", e, ) - results[f"flashinfer_fused_allreduce_rmsnorm_fp4_quant{suffix}"] = ( - float("inf") - ) - - # FlashInfer Fused AllReduce + RMSNorm + FP4 Quant Two-shot - if flashinfer_comm is not None and allreduce_params is not None: - try: - time_ms = benchmark_operation( - flashinfer_fused_allreduce_rmsnorm_fp4_quant, - input_tensor, - residual=residual, - norm_out=norm_out, - rms_gamma=rms_gamma, - rms_eps=rms_eps, - input_global_scale=scale_fp4, - allreduce_params=allreduce_params, - quant_out=fp4_quant_out, - output_scale=fp4_output_scale, - use_oneshot=False, - ) - results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = ( - time_ms - ) - except Exception as e: - logger.error( - "FlashInfer Fused AllReduce+RMSNorm+FP4 Two-shot failed: %s", - e, - ) - results["flashinfer_fused_allreduce_rmsnorm_fp4_quant_twoshot"] = float( - "inf" - ) + results[key] = float("inf") return results @@ -1021,8 +1035,7 @@ def main(): configs = list(itertools.product(args.num_tokens, dtypes, residual_options)) - # Setup FlashInfer workspace if available - workspace = None + # Setup FlashInfer workspaces for all backends allreduce_params = None if flashinfer_comm is not None: @@ -1037,15 +1050,17 @@ def main(): args.hidden_dim * max_element_size ) - workspace = setup_flashinfer_workspace( - world_size, - rank, - args.hidden_dim, - max_num_token, - dtype=workspace_dtype, - ) + for backend in FLASHINFER_BACKENDS: + setup_flashinfer_workspace( + backend=backend, + world_size=world_size, + rank=rank, + hidden_dim=args.hidden_dim, + max_token_num=max_num_token, + dtype=workspace_dtype, + ) - if workspace is not None: + if _FI_WORKSPACES: allreduce_params = FlashInferFusedAllReduceParams( max_token_num=max_num_token, ) @@ -1071,6 +1086,7 @@ def main(): dtype, use_residual, allreduce_params, + workspaces=_FI_WORKSPACES, quant_modes=quant_modes, no_oneshot=args.no_oneshot, ) @@ -1109,11 +1125,13 @@ def main(): finally: # Cleanup - if workspace is not None: - cleanup_flashinfer_workspace(workspace) + cleanup_flashinfer_workspaces() dist.barrier() if __name__ == "__main__": - main() + from vllm.config import VllmConfig, set_current_vllm_config + + with set_current_vllm_config(VllmConfig()): + main() diff --git a/cmake/cpu_extension.cmake b/cmake/cpu_extension.cmake index 5a0980dcc965..dde8cc20751b 100644 --- a/cmake/cpu_extension.cmake +++ b/cmake/cpu_extension.cmake @@ -13,28 +13,16 @@ endif() # # Define environment variables for special configurations # -set(ENABLE_AVX2 $ENV{VLLM_CPU_AVX2}) -set(ENABLE_AVX512 $ENV{VLLM_CPU_AVX512}) -set(ENABLE_AVX512BF16 $ENV{VLLM_CPU_AVX512BF16}) -set(ENABLE_AVX512VNNI $ENV{VLLM_CPU_AVX512VNNI}) -set(ENABLE_AMXBF16 $ENV{VLLM_CPU_AMXBF16}) +set(ENABLE_X86_ISA $ENV{VLLM_CPU_X86}) set(ENABLE_ARM_BF16 $ENV{VLLM_CPU_ARM_BF16}) include_directories("${CMAKE_SOURCE_DIR}/csrc") - set (ENABLE_NUMA TRUE) # # Check the compile flags # - -if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64") - list(APPEND CXX_COMPILE_FLAGS - "-mf16c" - ) -endif() - if(MACOSX_FOUND) list(APPEND CXX_COMPILE_FLAGS "-DVLLM_CPU_EXTENSION") @@ -78,18 +66,6 @@ function(check_sysctl TARGET OUT) endif() endfunction() - -function (is_avx512_disabled OUT) - set(DISABLE_AVX512 $ENV{VLLM_CPU_DISABLE_AVX512}) - if(DISABLE_AVX512 AND DISABLE_AVX512 STREQUAL "true") - set(${OUT} ON PARENT_SCOPE) - else() - set(${OUT} OFF PARENT_SCOPE) - endif() -endfunction() - -is_avx512_disabled(AVX512_DISABLED) - if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") message(STATUS "Apple Silicon Detected") set(APPLE_SILICON_FOUND TRUE) @@ -97,8 +73,6 @@ if (MACOSX_FOUND AND CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") check_sysctl(hw.optional.neon ASIMD_FOUND) check_sysctl(hw.optional.arm.FEAT_BF16 ARM_BF16_FOUND) else() - find_isa(${CPUINFO} "avx2" AVX2_FOUND) - find_isa(${CPUINFO} "avx512f" AVX512_FOUND) find_isa(${CPUINFO} "Power11" POWER11_FOUND) find_isa(${CPUINFO} "POWER10" POWER10_FOUND) find_isa(${CPUINFO} "POWER9" POWER9_FOUND) @@ -108,77 +82,32 @@ else() find_isa(${CPUINFO} "v" RVV_FOUND) # Check for RISC-V RVV support # Support cross-compilation by allowing override via environment variables - if (ENABLE_AVX2) - set(AVX2_FOUND ON) - message(STATUS "AVX2 support enabled via VLLM_CPU_AVX2 environment variable") - endif() - if (ENABLE_AVX512) - set(AVX512_FOUND ON) - message(STATUS "AVX512 support enabled via VLLM_CPU_AVX512 environment variable") - endif() if (ENABLE_ARM_BF16) set(ARM_BF16_FOUND ON) message(STATUS "ARM BF16 support enabled via VLLM_CPU_ARM_BF16 environment variable") endif() endif() -if (AVX512_FOUND AND NOT AVX512_DISABLED) - list(APPEND CXX_COMPILE_FLAGS +if (CMAKE_SYSTEM_PROCESSOR MATCHES "x86_64|amd64" OR ENABLE_X86_ISA) + set(ENABLE_X86_ISA ON) + if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND + CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3)) + message(FATAL_ERROR "X86 backend requires gcc/g++ >= 12.3") + endif() + list(APPEND CXX_COMPILE_FLAGS "-mf16c") + list(APPEND CXX_COMPILE_FLAGS_AVX512 ${CXX_COMPILE_FLAGS}) + list(APPEND CXX_COMPILE_FLAGS_AVX2 ${CXX_COMPILE_FLAGS}) + list(APPEND CXX_COMPILE_FLAGS_AVX512 "-mavx512f" "-mavx512vl" "-mavx512bw" - "-mavx512dq") - - find_isa(${CPUINFO} "avx512_bf16" AVX512BF16_FOUND) - if (AVX512BF16_FOUND OR ENABLE_AVX512BF16) - if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND - CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) - list(APPEND CXX_COMPILE_FLAGS "-mavx512bf16") - set(ENABLE_AVX512BF16 ON) - else() - set(ENABLE_AVX512BF16 OFF) - message(WARNING "Disable AVX512-BF16 ISA support, requires gcc/g++ >= 12.3") - endif() - else() - set(ENABLE_AVX512BF16 OFF) - message(WARNING "Disable AVX512-BF16 ISA support, no avx512_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512BF16=1.") - endif() - - find_isa(${CPUINFO} "avx512_vnni" AVX512VNNI_FOUND) - if (AVX512VNNI_FOUND OR ENABLE_AVX512VNNI) - if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND - CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) - list(APPEND CXX_COMPILE_FLAGS "-mavx512vnni") - set(ENABLE_AVX512VNNI ON) - else() - set(ENABLE_AVX512VNNI OFF) - message(WARNING "Disable AVX512-VNNI ISA support, requires gcc/g++ >= 12.3") - endif() - else() - set(ENABLE_AVX512VNNI OFF) - message(WARNING "Disable AVX512-VNNI ISA support, no avx512_vnni found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AVX512VNNI=1.") - endif() - - find_isa(${CPUINFO} "amx_bf16" AMXBF16_FOUND) - if (AMXBF16_FOUND OR ENABLE_AMXBF16) - if (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND - CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 12.3) - list(APPEND CXX_COMPILE_FLAGS "-mamx-bf16" "-mamx-tile") - set(ENABLE_AMXBF16 ON) - add_compile_definitions(-DCPU_CAPABILITY_AMXBF16) - else() - set(ENABLE_AMXBF16 OFF) - message(WARNING "Disable AMX_BF16 ISA support, requires gcc/g++ >= 12.3") - endif() - else() - set(ENABLE_AMXBF16 OFF) - message(WARNING "Disable AMX_BF16 ISA support, no amx_bf16 found in local CPU flags." " If cross-compilation is required, please set env VLLM_CPU_AMXBF16=1.") - endif() - -elseif (AVX2_FOUND) - list(APPEND CXX_COMPILE_FLAGS "-mavx2") - message(WARNING "vLLM CPU backend using AVX2 ISA") - + "-mavx512dq" + "-mavx512bf16" + "-mavx512vnni" + "-mamx-bf16" + "-mamx-tile") + list(APPEND CXX_COMPILE_FLAGS_AVX2 + "-mavx2") elseif (POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) message(STATUS "PowerPC detected") if (POWER9_FOUND) @@ -219,12 +148,12 @@ elseif (CMAKE_SYSTEM_PROCESSOR MATCHES "riscv64") list(APPEND CXX_COMPILE_FLAGS "-march=rv64gc") endif() else() - message(FATAL_ERROR "vLLM CPU backend requires AVX512, AVX2, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.") + message(FATAL_ERROR "vLLM CPU backend requires X86, Power9+ ISA, S390X ISA, ARMv8 or RISC-V support.") endif() -# Build oneDNN for GEMM kernels (only for x86-AVX512 /ARM platforms) -if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) +# Build oneDNN for GEMM kernels +if (ENABLE_X86_ISA OR (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) OR POWER9_FOUND OR POWER10_FOUND OR POWER11_FOUND) # Fetch and build Arm Compute Library (ACL) as oneDNN's backend for AArch64 # TODO [fadara01]: remove this once ACL can be fetched and built automatically as a dependency of oneDNN set(ONEDNN_AARCH64_USE_ACL OFF CACHE BOOL "") @@ -329,13 +258,21 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON set(ONEDNN_ENABLE_WORKLOAD "INFERENCE") set(ONEDNN_ENABLE_PRIMITIVE "MATMUL;REORDER") set(ONEDNN_BUILD_GRAPH "OFF") - set(ONEDNN_ENABLE_JIT_PROFILING "OFF") + set(ONEDNN_ENABLE_JIT_PROFILING "ON") set(ONEDNN_ENABLE_ITT_TASKS "OFF") - set(ONEDNN_ENABLE_MAX_CPU_ISA "OFF") - set(ONEDNN_ENABLE_CPU_ISA_HINTS "OFF") - set(ONEDNN_VERBOSE "OFF") + set(ONEDNN_ENABLE_MAX_CPU_ISA "ON") + set(ONEDNN_ENABLE_CPU_ISA_HINTS "ON") + set(ONEDNN_VERBOSE "ON") set(CMAKE_POLICY_DEFAULT_CMP0077 NEW) + # TODO: Refactor this + if (ENABLE_X86_ISA) + # Note: only enable oneDNN for AVX512 + list(APPEND DNNL_COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512}) + else() + list(APPEND DNNL_COMPILE_FLAGS ${CXX_COMPILE_FLAGS}) + endif() + set(VLLM_BUILD_TYPE ${CMAKE_BUILD_TYPE}) set(CMAKE_BUILD_TYPE "Release") # remove oneDNN debug symbols to reduce size FetchContent_MakeAvailable(oneDNN) @@ -348,14 +285,20 @@ if ((AVX512_FOUND AND NOT AVX512_DISABLED) OR (ASIMD_FOUND AND NOT APPLE_SILICON PRIVATE ${oneDNN_SOURCE_DIR}/src ) target_link_libraries(dnnl_ext dnnl torch) - target_compile_options(dnnl_ext PRIVATE ${CXX_COMPILE_FLAGS} -fPIC) + target_compile_options(dnnl_ext PRIVATE ${DNNL_COMPILE_FLAGS} -fPIC) list(APPEND LIBS dnnl_ext) set(USE_ONEDNN ON) else() set(USE_ONEDNN OFF) endif() -message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") +# TODO: Refactor this +if (ENABLE_X86_ISA) + message(STATUS "CPU extension (AVX512) compile flags: ${CXX_COMPILE_FLAGS_AVX512}") + message(STATUS "CPU extension (AVX2) compile flags: ${CXX_COMPILE_FLAGS_AVX2}") +else() + message(STATUS "CPU extension compile flags: ${CXX_COMPILE_FLAGS}") +endif() if(ENABLE_NUMA) list(APPEND LIBS numa) @@ -390,25 +333,6 @@ set(VLLM_EXT_SRC "csrc/cpu/cpu_attn.cpp" "csrc/cpu/torch_bindings.cpp") -if (AVX512_FOUND AND NOT AVX512_DISABLED) - set(VLLM_EXT_SRC - "csrc/cpu/shm.cpp" - "csrc/cpu/cpu_wna16.cpp" - "csrc/cpu/cpu_fused_moe.cpp" - ${VLLM_EXT_SRC}) - if (ENABLE_AVX512BF16 AND ENABLE_AVX512VNNI) - set(VLLM_EXT_SRC - "csrc/cpu/sgl-kernels/gemm.cpp" - "csrc/cpu/sgl-kernels/gemm_int8.cpp" - "csrc/cpu/sgl-kernels/gemm_fp8.cpp" - "csrc/cpu/sgl-kernels/moe.cpp" - "csrc/cpu/sgl-kernels/moe_int8.cpp" - "csrc/cpu/sgl-kernels/moe_fp8.cpp" - ${VLLM_EXT_SRC}) - add_compile_definitions(-DCPU_CAPABILITY_AVX512) - endif() -endif() - if (ASIMD_FOUND AND NOT APPLE_SILICON_FOUND) set(VLLM_EXT_SRC "csrc/cpu/shm.cpp" @@ -421,21 +345,83 @@ if(USE_ONEDNN) ${VLLM_EXT_SRC}) endif() -message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}") - -# -# Define extension targets -# +if (ENABLE_X86_ISA) + set(VLLM_EXT_SRC_AVX512 + "csrc/cpu/sgl-kernels/gemm.cpp" + "csrc/cpu/sgl-kernels/gemm_int8.cpp" + "csrc/cpu/sgl-kernels/gemm_fp8.cpp" + "csrc/cpu/sgl-kernels/moe.cpp" + "csrc/cpu/sgl-kernels/moe_int8.cpp" + "csrc/cpu/sgl-kernels/moe_fp8.cpp" + "csrc/cpu/shm.cpp" + "csrc/cpu/cpu_wna16.cpp" + "csrc/cpu/cpu_fused_moe.cpp" + "csrc/cpu/utils.cpp" + "csrc/cpu/cpu_attn.cpp" + "csrc/cpu/dnnl_kernels.cpp" + "csrc/cpu/torch_bindings.cpp" + # TODO: Remove these files + "csrc/cpu/activation.cpp" + "csrc/cpu/layernorm.cpp" + "csrc/cpu/mla_decode.cpp" + "csrc/cpu/pos_encoding.cpp" + "csrc/moe/dynamic_4bit_int_moe_cpu.cpp") + + set(VLLM_EXT_SRC_AVX2 + "csrc/cpu/utils.cpp" + "csrc/cpu/cpu_attn.cpp" + "csrc/cpu/torch_bindings.cpp" + # TODO: Remove these files + "csrc/cpu/activation.cpp" + "csrc/cpu/layernorm.cpp" + "csrc/cpu/mla_decode.cpp" + "csrc/cpu/pos_encoding.cpp" + "csrc/moe/dynamic_4bit_int_moe_cpu.cpp") + + message(STATUS "CPU extension (AVX512) source files: ${VLLM_EXT_SRC_AVX512}") + message(STATUS "CPU extension (AVX2) source files: ${VLLM_EXT_SRC_AVX2}") + + define_extension_target( + _C + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_EXT_SRC_AVX512} + LIBRARIES ${LIBS} + COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX512} + USE_SABI 3 + WITH_SOABI + ) -define_extension_target( - _C - DESTINATION vllm - LANGUAGE CXX - SOURCES ${VLLM_EXT_SRC} - LIBRARIES ${LIBS} - COMPILE_FLAGS ${CXX_COMPILE_FLAGS} - USE_SABI 3 - WITH_SOABI -) + # For SGL kernels + target_compile_definitions(_C PRIVATE "-DCPU_CAPABILITY_AVX512") + # For AMX kernels + target_compile_definitions(_C PRIVATE "-DCPU_CAPABILITY_AMXBF16") + + define_extension_target( + _C_AVX2 + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_EXT_SRC_AVX2} + LIBRARIES ${LIBS} + COMPILE_FLAGS ${CXX_COMPILE_FLAGS_AVX2} + USE_SABI 3 + WITH_SOABI + ) +else() + message(STATUS "CPU extension source files: ${VLLM_EXT_SRC}") + # + # Define extension targets + # + define_extension_target( + _C + DESTINATION vllm + LANGUAGE CXX + SOURCES ${VLLM_EXT_SRC} + LIBRARIES ${LIBS} + COMPILE_FLAGS ${CXX_COMPILE_FLAGS} + USE_SABI 3 + WITH_SOABI + ) +endif() message(STATUS "Enabling C extension.") diff --git a/cmake/external_projects/vllm_flash_attn.cmake b/cmake/external_projects/vllm_flash_attn.cmake index 41c4e308d0be..c206b9c39ee1 100644 --- a/cmake/external_projects/vllm_flash_attn.cmake +++ b/cmake/external_projects/vllm_flash_attn.cmake @@ -17,7 +17,8 @@ endif() # They should be identical but if they aren't, this is a massive footgun. # # The vllm-flash-attn install rules are nested under vllm to make sure the library gets installed in the correct place. -# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2) or --component _vllm_fa3_C (for FA3). +# To only install vllm-flash-attn, use --component _vllm_fa2_C (for FA2), --component _vllm_fa3_C (for FA3), +# or --component _vllm_fa4_cutedsl_C (for FA4 CuteDSL Python files). # If no component is specified, vllm-flash-attn is still installed. # If VLLM_FLASH_ATTN_SRC_DIR is set, vllm-flash-attn is installed from that directory instead of downloading. @@ -38,7 +39,7 @@ else() FetchContent_Declare( vllm-flash-attn GIT_REPOSITORY https://github.com/vllm-project/flash-attention.git - GIT_TAG 5824e6e2008271063c3229ab3e7032bd74abbbc6 + GIT_TAG 140c00c0241bb60cc6e44e7c1be9998d4b20d8d2 GIT_PROGRESS TRUE # Don't share the vllm-flash-attn build between build types BINARY_DIR ${CMAKE_BINARY_DIR}/vllm-flash-attn @@ -46,38 +47,62 @@ else() endif() -# Ensure the vllm/vllm_flash_attn directory exists before installation -install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" ALL_COMPONENTS) - -# Make sure vllm-flash-attn install rules are nested under vllm/ -# This is here to support installing all components under the same prefix with cmake --install. -# setup.py installs every component separately but uses the same prefix for all. -# ALL_COMPONENTS is used to avoid duplication for FA2 and FA3, -# and these statements don't hurt when installing neither component. -install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" ALL_COMPONENTS) -install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS) -install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" ALL_COMPONENTS) +# Install rules for FA components need the install prefix nested under vllm/ +# These run at install time, before the FA library's own install rules +foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C) + install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY FALSE)" COMPONENT ${_FA_COMPONENT}) + install(CODE "set(OLD_CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}\")" COMPONENT ${_FA_COMPONENT}) + install(CODE "set(CMAKE_INSTALL_PREFIX \"\${CMAKE_INSTALL_PREFIX}/vllm/\")" COMPONENT ${_FA_COMPONENT}) +endforeach() # Fetch the vllm-flash-attn library FetchContent_MakeAvailable(vllm-flash-attn) message(STATUS "vllm-flash-attn is available at ${vllm-flash-attn_SOURCE_DIR}") -# Restore the install prefix -install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" ALL_COMPONENTS) -install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" ALL_COMPONENTS) +# Restore the install prefix after FA's install rules +foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C) + install(CODE "set(CMAKE_INSTALL_PREFIX \"\${OLD_CMAKE_INSTALL_PREFIX}\")" COMPONENT ${_FA_COMPONENT}) + install(CODE "set(CMAKE_INSTALL_LOCAL_ONLY TRUE)" COMPONENT ${_FA_COMPONENT}) +endforeach() + +# Install shared Python files for both FA2 and FA3 components +foreach(_FA_COMPONENT _vllm_fa2_C _vllm_fa3_C) + # Ensure the vllm/vllm_flash_attn directory exists before installation + install(CODE "file(MAKE_DIRECTORY \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn\")" + COMPONENT ${_FA_COMPONENT}) + + # Copy vllm_flash_attn python files (except __init__.py and flash_attn_interface.py + # which are source-controlled in vllm) + install( + DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ + DESTINATION vllm/vllm_flash_attn + COMPONENT ${_FA_COMPONENT} + FILES_MATCHING PATTERN "*.py" + PATTERN "__init__.py" EXCLUDE + PATTERN "flash_attn_interface.py" EXCLUDE + ) + +endforeach() -# Copy over the vllm-flash-attn python files (duplicated for fa2 and fa3, in -# case only one is built, in the case both are built redundant work is done) -install( - DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ - DESTINATION vllm/vllm_flash_attn - COMPONENT _vllm_fa2_C - FILES_MATCHING PATTERN "*.py" -) +# +# FA4 CuteDSL component +# This is a Python-only component that copies the flash_attn/cute directory +# and transforms imports to match our package structure. +# +add_custom_target(_vllm_fa4_cutedsl_C) -install( - DIRECTORY ${vllm-flash-attn_SOURCE_DIR}/vllm_flash_attn/ - DESTINATION vllm/vllm_flash_attn - COMPONENT _vllm_fa3_C - FILES_MATCHING PATTERN "*.py" -) +# Copy flash_attn/cute directory (needed for FA4) and transform imports +# The cute directory uses flash_attn.cute imports internally, which we replace +# with vllm.vllm_flash_attn.cute to match our package structure. +install(CODE " + file(GLOB_RECURSE CUTE_PY_FILES \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute/*.py\") + foreach(SRC_FILE \${CUTE_PY_FILES}) + file(RELATIVE_PATH REL_PATH \"${vllm-flash-attn_SOURCE_DIR}/flash_attn/cute\" \${SRC_FILE}) + set(DST_FILE \"\${CMAKE_INSTALL_PREFIX}/vllm/vllm_flash_attn/cute/\${REL_PATH}\") + get_filename_component(DST_DIR \${DST_FILE} DIRECTORY) + file(MAKE_DIRECTORY \${DST_DIR}) + file(READ \${SRC_FILE} FILE_CONTENTS) + string(REPLACE \"flash_attn.cute\" \"vllm.vllm_flash_attn.cute\" FILE_CONTENTS \"\${FILE_CONTENTS}\") + file(WRITE \${DST_FILE} \"\${FILE_CONTENTS}\") + endforeach() +" COMPONENT _vllm_fa4_cutedsl_C) diff --git a/csrc/activation_kernels.cu b/csrc/activation_kernels.cu index 99fa42f75e99..758a77795553 100644 --- a/csrc/activation_kernels.cu +++ b/csrc/activation_kernels.cu @@ -5,117 +5,11 @@ #include #include "cuda_compat.h" +#include "cuda_vec_utils.cuh" #include "dispatch_utils.h" namespace vllm { -struct alignas(32) u32x8_t { - uint32_t u0, u1, u2, u3, u4, u5, u6, u7; -}; - -__device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \ - defined(CUDA_VERSION) && CUDA_VERSION >= 12090 - asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n" - : "=r"(val.u0), "=r"(val.u1), "=r"(val.u2), "=r"(val.u3), - "=r"(val.u4), "=r"(val.u5), "=r"(val.u6), "=r"(val.u7) - : "l"(ptr)); -#else - const uint4* uint_ptr = reinterpret_cast(ptr); - uint4 top_half = __ldg(&uint_ptr[0]); - uint4 bottom_half = __ldg(&uint_ptr[1]); - val.u0 = top_half.x; - val.u1 = top_half.y; - val.u2 = top_half.z; - val.u3 = top_half.w; - val.u4 = bottom_half.x; - val.u5 = bottom_half.y; - val.u6 = bottom_half.z; - val.u7 = bottom_half.w; -#endif -} - -__device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) { -#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \ - defined(CUDA_VERSION) && CUDA_VERSION >= 12090 - asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n" - : - : "l"(ptr), "r"(val.u0), "r"(val.u1), "r"(val.u2), "r"(val.u3), - "r"(val.u4), "r"(val.u5), "r"(val.u6), "r"(val.u7) - : "memory"); -#else - uint4* uint_ptr = reinterpret_cast(ptr); - uint_ptr[0] = make_uint4(val.u0, val.u1, val.u2, val.u3); - uint_ptr[1] = make_uint4(val.u4, val.u5, val.u6, val.u7); -#endif -} - -template -struct VecTraits; - -template <> -struct VecTraits { - static constexpr int ARCH_MAX_VEC_SIZE = 32; - using vec_t = u32x8_t; -}; - -template <> -struct VecTraits { - static constexpr int ARCH_MAX_VEC_SIZE = 16; - using vec_t = int4; -}; - -template -struct PackedTraits; - -template <> -struct PackedTraits { - using packed_t = __nv_bfloat162; -}; - -template <> -struct PackedTraits { - using packed_t = __half2; -}; - -template <> -struct PackedTraits { - using packed_t = float2; -}; - -template -__device__ __forceinline__ float2 cast_to_float2(const packed_t& val) { - if constexpr (std::is_same_v) { - return __bfloat1622float2(val); - } else if constexpr (std::is_same_v) { - return __half22float2(val); - } else if constexpr (std::is_same_v) { - return float2(val); - } -} - -template -__device__ __forceinline__ packed_t cast_to_packed(const float2& val) { - if constexpr (std::is_same_v) { - return __float22bfloat162_rn(val); - } else if constexpr (std::is_same_v) { - return __float22half2_rn(val); - } else if constexpr (std::is_same_v) { - return float2(val); - } -} - -template -__device__ __forceinline__ packed_t packed_mul(const packed_t& x, - const packed_t& y) { - if constexpr (std::is_same_v || - std::is_same_v) { - return __hmul2(x, y); - } else if constexpr (std::is_same_v) { - return make_float2(x.x * y.x, x.y * y.y); - } -} - template __device__ __forceinline__ scalar_t compute(const scalar_t& x, @@ -131,16 +25,6 @@ __device__ __forceinline__ packed_t packed_compute(const packed_t& x, : packed_mul(x, PACKED_ACT_FN(y)); } -// Check if all pointers are 16-byte aligned for int4 vectorized access -__host__ __device__ __forceinline__ bool is_16byte_aligned(const void* ptr) { - return (reinterpret_cast(ptr) & 15) == 0; -} - -// Check if all pointers are 16-byte aligned for longlong4_32a vectorized access -__host__ __device__ __forceinline__ bool is_32byte_aligned(const void* ptr) { - return (reinterpret_cast(ptr) & 31) == 0; -} - // Activation and gating kernel template. template ::vec_t; - constexpr int ARCH_MAX_VEC_SIZE = VecTraits::ARCH_MAX_VEC_SIZE; - constexpr int VEC_SIZE = ARCH_MAX_VEC_SIZE / sizeof(packed_t); + using cuda_t = typename CUDATypeConverter::Type; + using pvec_t = PackedVec; - const vec_t* x_vec = reinterpret_cast(x_ptr); - const vec_t* y_vec = reinterpret_cast(y_ptr); - vec_t* out_vec = reinterpret_cast(out_ptr); - const int num_vecs = d / 2 / VEC_SIZE; + const pvec_t* x_vec = reinterpret_cast(x_ptr); + const pvec_t* y_vec = reinterpret_cast(y_ptr); + pvec_t* out_vec = reinterpret_cast(out_ptr); + const int num_vecs = d / 2 / pvec_t::NUM_ELTS; for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { - vec_t x, y; + pvec_t x, y; if constexpr (use_256b) { ld256(x, &x_vec[i]); ld256(y, &y_vec[i]); } else { - x = VLLM_LDG(&x_vec[i]); - y = VLLM_LDG(&y_vec[i]); + ld128(x, &x_vec[i]); + ld128(y, &y_vec[i]); } - auto* xp = reinterpret_cast(&x); - auto* yp = reinterpret_cast(&y); #pragma unroll - for (int j = 0; j < VEC_SIZE; j++) { - xp[j] = - packed_compute(xp[j], yp[j]); + for (int j = 0; j < pvec_t::NUM_ELTS; j++) { + x.elts[j] = packed_compute( + x.elts[j], y.elts[j]); } if constexpr (use_256b) { st256(x, &out_vec[i]); } else { - out_vec[i] = x; + st128(x, &out_vec[i]); } } } else { @@ -272,51 +152,54 @@ packed_gelu_tanh_kernel(const packed_t& val) { // Launch activation and gating kernel. // Use ACT_FIRST (bool) indicating whether to apply the activation function // first. -#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \ - auto dtype = input.scalar_type(); \ - int d = input.size(-1) / 2; \ - int64_t num_tokens = input.numel() / input.size(-1); \ - if (num_tokens == 0) { \ - return; \ - } \ - dim3 grid(num_tokens); \ - int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ - int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \ - int vec_size = support_vec / at::elementSize(dtype); \ - const bool use_vec = (d % vec_size == 0); \ - const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ - const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ - if (use_vec) { \ - dim3 block(std::min(d / vec_size, 1024)); \ - if (cc_major >= 10 && num_tokens > 128) { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ - KERNEL, \ - PACKED_KERNEL::packed_t>, \ - ACT_FIRST, true, true><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ - }); \ - } else { \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ - KERNEL, \ - PACKED_KERNEL::packed_t>, \ - ACT_FIRST, true, false><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ - }); \ - } \ - } else { \ - dim3 block(std::min(d, 1024)); \ - VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ - vllm::act_and_mul_kernel< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ - KERNEL, \ - PACKED_KERNEL::packed_t>, \ - ACT_FIRST, false><<>>( \ - out.data_ptr(), input.data_ptr(), d); \ - }); \ +#define LAUNCH_ACTIVATION_GATE_KERNEL(KERNEL, PACKED_KERNEL, ACT_FIRST) \ + auto dtype = input.scalar_type(); \ + int d = input.size(-1) / 2; \ + int64_t num_tokens = input.numel() / input.size(-1); \ + if (num_tokens == 0) { \ + return; \ + } \ + dim3 grid(num_tokens); \ + int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ + int vec_size = support_vec / at::elementSize(dtype); \ + const bool use_vec = (d % vec_size == 0); \ + const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ + const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ + if (use_vec) { \ + dim3 block(std::min(d / vec_size, 1024)); \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL::Type>, \ + ACT_FIRST, true, true><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); \ + } else { \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL::Type>, \ + ACT_FIRST, true, false><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); \ + } \ + } else { \ + dim3 block(std::min(d, 1024)); \ + VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel", [&] { \ + vllm::act_and_mul_kernel< \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ + KERNEL, \ + PACKED_KERNEL::Type>, \ + ACT_FIRST, false><<>>( \ + out.data_ptr(), input.data_ptr(), d); \ + }); \ } void silu_and_mul(torch::Tensor& out, // [..., d] @@ -378,35 +261,31 @@ __global__ void act_and_mul_kernel_with_param( scalar_t* out_ptr = out + blockIdx.x * d; if constexpr (use_vec) { - // Fast path: 128-bit/256-bit vectorized loop - using vec_t = typename VecTraits::vec_t; - constexpr int ARCH_MAX_VEC_SIZE = VecTraits::ARCH_MAX_VEC_SIZE; - constexpr int VEC_SIZE = ARCH_MAX_VEC_SIZE / sizeof(packed_t); + using cuda_t = typename CUDATypeConverter::Type; + using pvec_t = PackedVec; - const vec_t* x_vec = reinterpret_cast(x_ptr); - const vec_t* y_vec = reinterpret_cast(y_ptr); - vec_t* out_vec = reinterpret_cast(out_ptr); - const int num_vecs = d / 2 / VEC_SIZE; + const pvec_t* x_vec = reinterpret_cast(x_ptr); + const pvec_t* y_vec = reinterpret_cast(y_ptr); + pvec_t* out_vec = reinterpret_cast(out_ptr); + const int num_vecs = d / 2 / pvec_t::NUM_ELTS; for (int i = threadIdx.x; i < num_vecs; i += blockDim.x) { - vec_t x, y; + pvec_t x, y; if constexpr (use_256b) { ld256(x, &x_vec[i]); ld256(y, &y_vec[i]); } else { - x = VLLM_LDG(&x_vec[i]); - y = VLLM_LDG(&y_vec[i]); + ld128(x, &x_vec[i]); + ld128(y, &y_vec[i]); } - auto* xp = reinterpret_cast(&x); - auto* yp = reinterpret_cast(&y); #pragma unroll - for (int j = 0; j < VEC_SIZE; j++) { - xp[j] = packed_mul(PACKED_ACT_FN(xp[j], param), yp[j]); + for (int j = 0; j < pvec_t::NUM_ELTS; j++) { + x.elts[j] = packed_mul(PACKED_ACT_FN(x.elts[j], param), y.elts[j]); } if constexpr (use_256b) { st256(x, &out_vec[i]); } else { - out_vec[i] = x; + st128(x, &out_vec[i]); } } } else { @@ -499,21 +378,24 @@ __global__ void swigluoai_and_mul_kernel( } \ dim3 grid(num_tokens); \ int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ - int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ int vec_size = support_vec / at::elementSize(dtype); \ const bool use_vec = (d % vec_size == 0); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ if (use_vec) { \ dim3 block(std::min(d / vec_size, 1024)); \ - if (cc_major >= 10 && num_tokens > 128) { \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ VLLM_DISPATCH_FLOATING_TYPES( \ dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ KERNEL, \ PACKED_KERNEL< \ - typename vllm::PackedTraits::packed_t>, \ + typename vllm::PackedTypeConverter::Type>, \ true, true><<>>( \ out.data_ptr(), input.data_ptr(), d, \ PARAM); \ @@ -522,10 +404,10 @@ __global__ void swigluoai_and_mul_kernel( VLLM_DISPATCH_FLOATING_TYPES( \ dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ KERNEL, \ PACKED_KERNEL< \ - typename vllm::PackedTraits::packed_t>, \ + typename vllm::PackedTypeConverter::Type>, \ true, false><<>>( \ out.data_ptr(), input.data_ptr(), d, \ PARAM); \ @@ -535,9 +417,9 @@ __global__ void swigluoai_and_mul_kernel( dim3 block(std::min(d, 1024)); \ VLLM_DISPATCH_FLOATING_TYPES(dtype, "act_and_mul_kernel_with_param", [&] { \ vllm::act_and_mul_kernel_with_param< \ - scalar_t, typename vllm::PackedTraits::packed_t, \ + scalar_t, typename vllm::PackedTypeConverter::Type, \ KERNEL, \ - PACKED_KERNEL::packed_t>, \ + PACKED_KERNEL::Type>, \ false><<>>( \ out.data_ptr(), input.data_ptr(), d, PARAM); \ }); \ @@ -629,14 +511,17 @@ __global__ void activation_kernel( } \ dim3 grid(num_tokens); \ int cc_major = at::cuda::getCurrentDeviceProperties()->major; \ - int support_vec = (cc_major >= 10 && num_tokens > 128) ? 32 : 16; \ + int support_vec = \ + (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) \ + ? vllm::VecTraits::ARCH_MAX_VEC_SIZE \ + : vllm::VecTraits::ARCH_MAX_VEC_SIZE; \ int vec_size = support_vec / at::elementSize(dtype); \ const bool use_vec = (d % vec_size == 0); \ const at::cuda::OptionalCUDAGuard device_guard(device_of(input)); \ const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); \ if (use_vec) { \ dim3 block(std::min(d / vec_size, 1024)); \ - if (cc_major >= 10 && num_tokens > 128) { \ + if (CUDA_VERSION >= 12090 && cc_major >= 10 && num_tokens > 128) { \ VLLM_DISPATCH_FLOATING_TYPES(dtype, "activation_kernel", [&] { \ vllm::activation_kernel, true, true> \ <<>>(out.data_ptr(), \ diff --git a/csrc/cpu/torch_bindings.cpp b/csrc/cpu/torch_bindings.cpp index 11e1305c6027..2ea482148d4c 100644 --- a/csrc/cpu/torch_bindings.cpp +++ b/csrc/cpu/torch_bindings.cpp @@ -4,6 +4,10 @@ #include +// Note: overwrite the external defination for sharing same name between +// libraries use different ISAs. +#define TORCH_EXTENSION_NAME _C + std::string init_cpu_threads_env(const std::string& cpu_ids); void release_dnnl_matmul_handler(int64_t handler); @@ -324,19 +328,12 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "str act, str isa) -> ()"); ops.impl("cpu_fused_moe", torch::kCPU, &cpu_fused_moe); #endif -} - -TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _utils), utils) { - // CPU utils - utils.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env); -} - -TORCH_LIBRARY_EXPAND(CONCAT(TORCH_EXTENSION_NAME, _cpu), cpu_ops) { - cpu_ops.def( + ops.def("init_cpu_threads_env(str cpu_ids) -> str", &init_cpu_threads_env); + ops.def( "mla_decode_kvcache(" " Tensor! out, Tensor query, Tensor kv_cache," " float scale, Tensor block_tables, Tensor seq_lens) -> ()"); - cpu_ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache); + ops.impl("mla_decode_kvcache", torch::kCPU, &mla_decode_kvcache); } REGISTER_EXTENSION(TORCH_EXTENSION_NAME) diff --git a/csrc/cuda_vec_utils.cuh b/csrc/cuda_vec_utils.cuh new file mode 100644 index 000000000000..82a19f10a70e --- /dev/null +++ b/csrc/cuda_vec_utils.cuh @@ -0,0 +1,334 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +#pragma once + +#include +#include +#include + +#ifdef USE_ROCM + #include +#else + #include + #include + #include +#endif + +// Device-side: SM100+ architecture with CUDA 12.9+ toolkit, which +// together enable 256-bit (v8.u32) PTX load/store instructions. +// Use for PTX instruction selection with architecture fallback paths. +#if !defined(USE_ROCM) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 && \ + defined(CUDA_VERSION) && CUDA_VERSION >= 12090 + #define VLLM_256B_PTX_ENABLED 1 +#else + #define VLLM_256B_PTX_ENABLED 0 +#endif + +namespace vllm { + +// ============================================================ +// Types and traits +// ============================================================ + +// 256-bit (32-byte) aligned vector type: 8 x uint32_t +struct alignas(32) u32x8_t { + uint32_t d[8]; +}; + +// VecTraits — select between 128-bit (int4) and 256-bit +// (u32x8_t) vector types at compile time. +template +struct VecTraits; + +template <> +struct VecTraits { + static constexpr int ARCH_MAX_VEC_SIZE = 32; + using vec_t = u32x8_t; +}; + +template <> +struct VecTraits { + static constexpr int ARCH_MAX_VEC_SIZE = 16; + using vec_t = int4; +}; + +// PackedTypeConverter — map between CUDA scalar and packed types +// half <-> half2, __nv_bfloat16 <-> __nv_bfloat162, etc. +template +struct PackedTypeConverter { + static_assert(sizeof(T) == 0, + "PackedTypeConverter is not specialized for this type."); +}; + +template <> +struct PackedTypeConverter { + using Type = half; +}; + +template <> +struct PackedTypeConverter { + using Type = half2; +}; + +template <> +struct PackedTypeConverter<__nv_bfloat162> { + using Type = __nv_bfloat16; +}; + +template <> +struct PackedTypeConverter<__nv_bfloat16> { + using Type = __nv_bfloat162; +}; + +template <> +struct PackedTypeConverter { + using Type = float2; +}; + +template <> +struct PackedTypeConverter { + using Type = float; +}; + +template <> +struct PackedTypeConverter { + using Type = half2; +}; + +template <> +struct PackedTypeConverter { + using Type = __nv_bfloat162; +}; + +// CUDATypeConverter — map PyTorch scalar types to CUDA scalar +// c10::Half -> half, c10::BFloat16 -> __nv_bfloat16 +template +struct CUDATypeConverter { + using Type = T; +}; + +template <> +struct CUDATypeConverter { + using Type = half; +}; + +template <> +struct CUDATypeConverter { + using Type = __nv_bfloat16; +}; + +// PackedVec — typed vector container for packed element access. +// Derives alignment and element count from VecTraits. +// Type is the CUDA scalar type (e.g. half, __nv_bfloat16). +template +struct alignas(VecTraits::ARCH_MAX_VEC_SIZE) PackedVec { + static constexpr int NUM_ELTS = + VecTraits::ARCH_MAX_VEC_SIZE / + sizeof(typename PackedTypeConverter::Type); + typename PackedTypeConverter::Type elts[NUM_ELTS]; +}; + +// ============================================================ +// Load / store primitives +// ============================================================ + +// 256-bit load / store — SM100+ only (PTX v8 instructions). +__device__ __forceinline__ void ld256(u32x8_t& val, const u32x8_t* ptr) { +#if VLLM_256B_PTX_ENABLED + asm volatile("ld.global.nc.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];\n" + : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), + "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) + : "l"(ptr)); +#else + assert(false && "ld256 requires SM100+ with CUDA 12.9+"); +#endif +} + +__device__ __forceinline__ void st256(u32x8_t& val, u32x8_t* ptr) { +#if VLLM_256B_PTX_ENABLED + asm volatile("st.global.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};\n" + : + : "l"(ptr), "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), + "r"(val.d[3]), "r"(val.d[4]), "r"(val.d[5]), "r"(val.d[6]), + "r"(val.d[7]) + : "memory"); +#else + assert(false && "st256 requires SM100+ with CUDA 12.9+"); +#endif +} + +// Generic ld256 / st256 for any 32-byte aligned type (e.g. PackedVec). +// Non-template overloads above are preferred for u32x8_t. +template +__device__ __forceinline__ void ld256(T& val, const T* ptr) { + static_assert(sizeof(T) == 32, "ld256 requires a 32-byte type"); + ld256(reinterpret_cast(val), reinterpret_cast(ptr)); +} + +template +__device__ __forceinline__ void st256(T& val, T* ptr) { + static_assert(sizeof(T) == 32, "st256 requires a 32-byte type"); + st256(reinterpret_cast(val), reinterpret_cast(ptr)); +} + +// 128-bit load / store via __ldg (read-only cache hint). +template +__device__ __forceinline__ void ld128(T& val, const T* ptr) { + static_assert(sizeof(T) == 16, "ld128 requires a 16-byte type"); + *reinterpret_cast(&val) = __ldg(reinterpret_cast(ptr)); +} + +template +__device__ __forceinline__ void st128(T& val, T* ptr) { + static_assert(sizeof(T) == 16, "st128 requires a 16-byte type"); + *reinterpret_cast(ptr) = *reinterpret_cast(&val); +} + +// 256-bit cache-streaming (.cs) load / store — SM100+ only. +__forceinline__ __device__ u32x8_t ld256_cs(const u32x8_t* addr) { +#if VLLM_256B_PTX_ENABLED + u32x8_t val; + asm volatile("ld.global.cs.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%8];" + : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), + "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) + : "l"(addr)); + return val; +#else + assert(false && "ld256_cs requires SM100+ with CUDA 12.9+"); + return {}; +#endif +} + +__forceinline__ __device__ void st256_cs(u32x8_t* addr, u32x8_t val) { +#if VLLM_256B_PTX_ENABLED + asm volatile( + "st.global.cs.v8.u32 [%0], {%1,%2,%3,%4,%5,%6,%7,%8};" ::"l"(addr), + "r"(val.d[0]), "r"(val.d[1]), "r"(val.d[2]), "r"(val.d[3]), "r"(val.d[4]), + "r"(val.d[5]), "r"(val.d[6]), "r"(val.d[7])); +#else + assert(false && "st256_cs requires SM100+ with CUDA 12.9+"); +#endif +} + +// 32-bit cache-streaming (.cs) load / store — SM100+ only. +__forceinline__ __device__ int ld32_cs(const int* addr) { +#if VLLM_256B_PTX_ENABLED + int val; + asm volatile("ld.global.cs.b32 %0, [%1];" : "=r"(val) : "l"(addr)); + return val; +#else + assert(false && "ld32_cs requires SM100+ with CUDA 12.9+"); + return 0; +#endif +} + +__forceinline__ __device__ void st32_cs(int* addr, int val) { +#if VLLM_256B_PTX_ENABLED + asm volatile("st.global.cs.b32 [%0], %1;" ::"l"(addr), "r"(val)); +#else + assert(false && "st32_cs requires SM100+ with CUDA 12.9+"); +#endif +} + +// Predicated 256-bit / 128-bit cache-global (.cg) loads. +// Returns zero if pred is false. SM100+ only. +__device__ __forceinline__ void ld256_cg_or_zero(u32x8_t& val, const void* ptr, + bool pred) { +#if VLLM_256B_PTX_ENABLED + asm volatile( + "{\n" + " .reg .pred pr;\n" + " setp.ne.u32 pr, %8, 0;\n" + " mov.u32 %0, 0;\n" + " mov.u32 %1, 0;\n" + " mov.u32 %2, 0;\n" + " mov.u32 %3, 0;\n" + " mov.u32 %4, 0;\n" + " mov.u32 %5, 0;\n" + " mov.u32 %6, 0;\n" + " mov.u32 %7, 0;\n" + " @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n" + "}\n" + : "=r"(val.d[0]), "=r"(val.d[1]), "=r"(val.d[2]), "=r"(val.d[3]), + "=r"(val.d[4]), "=r"(val.d[5]), "=r"(val.d[6]), "=r"(val.d[7]) + : "r"((int)pred), "l"(ptr)); +#else + assert(false && "ld256_cg_or_zero requires SM100+ with CUDA 12.9+"); +#endif +} + +__device__ __forceinline__ void ld128_cg_or_zero(uint4& val, const void* ptr, + bool pred) { +#if VLLM_256B_PTX_ENABLED + uint32_t r0, r1, r2, r3; + + asm volatile( + "{\n" + " .reg .pred pr;\n" + " setp.ne.u32 pr, %4, 0;\n" + " mov.u32 %0, 0;\n" + " mov.u32 %1, 0;\n" + " mov.u32 %2, 0;\n" + " mov.u32 %3, 0;\n" + " @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n" + "}\n" + : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) + : "r"((int)pred), "l"(ptr)); + + val = uint4{r0, r1, r2, r3}; +#else + assert(false && "ld128_cg_or_zero requires SM100+ with CUDA 12.9+"); +#endif +} + +// ============================================================ +// Alignment helpers +// ============================================================ + +__host__ __device__ __forceinline__ bool is_16byte_aligned(const void* ptr) { + return (reinterpret_cast(ptr) & 15) == 0; +} + +__host__ __device__ __forceinline__ bool is_32byte_aligned(const void* ptr) { + return (reinterpret_cast(ptr) & 31) == 0; +} + +// ============================================================ +// Packed type conversion and arithmetic +// ============================================================ + +template +__device__ __forceinline__ float2 cast_to_float2(const packed_t& val) { + if constexpr (std::is_same_v) { + return __bfloat1622float2(val); + } else if constexpr (std::is_same_v) { + return __half22float2(val); + } else if constexpr (std::is_same_v) { + return float2(val); + } +} + +template +__device__ __forceinline__ packed_t cast_to_packed(const float2& val) { + if constexpr (std::is_same_v) { + return __float22bfloat162_rn(val); + } else if constexpr (std::is_same_v) { + return __float22half2_rn(val); + } else if constexpr (std::is_same_v) { + return float2(val); + } +} + +template +__device__ __forceinline__ packed_t packed_mul(const packed_t& x, + const packed_t& y) { + if constexpr (std::is_same_v || + std::is_same_v) { + return __hmul2(x, y); + } else if constexpr (std::is_same_v) { + return make_float2(x.x * y.x, x.y * y.y); + } +} + +} // namespace vllm diff --git a/csrc/mamba/mamba_ssm/selective_scan.h b/csrc/mamba/mamba_ssm/selective_scan.h index 7d22dd8b84a3..8f33c7cfa163 100644 --- a/csrc/mamba/mamba_ssm/selective_scan.h +++ b/csrc/mamba/mamba_ssm/selective_scan.h @@ -15,9 +15,9 @@ //////////////////////////////////////////////////////////////////////////////////////////////////// struct SSMParamsBase { - using index_t = uint32_t; + using index_t = size_t; - int batch, dim, seqlen, dstate, n_groups, n_chunks; + int batch, dim, seqlen, dstate, n_groups; int dim_ngroups_ratio; bool is_variable_B; bool is_variable_C; @@ -72,6 +72,8 @@ struct SSMParamsBase { void *__restrict__ block_idx_first_scheduled_token_ptr; // (batch,) - first block to write void *__restrict__ block_idx_last_scheduled_token_ptr; // (batch,) - last block to write void *__restrict__ initial_state_idx_ptr; // (batch,) - index of the initial state to use + void *__restrict__ cu_chunk_seqlen_ptr; // (nchunks+1,) - cumulative chunk token offsets + void *__restrict__ last_chunk_indices_ptr; // (batch,) - index of last chunk per sequence }; diff --git a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu index fb2a2e578999..d852a0ed4928 100644 --- a/csrc/mamba/mamba_ssm/selective_scan_fwd.cu +++ b/csrc/mamba/mamba_ssm/selective_scan_fwd.cu @@ -81,7 +81,6 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { constexpr bool kIsVariableC = Ktraits::kIsVariableC; constexpr bool kHasZ = Ktraits::kHasZ; constexpr bool kVarlen = Ktraits::kVarlen; - constexpr int kNThreads = Ktraits::kNThreads; constexpr int kNItems = Ktraits::kNItems; constexpr int kNRows = Ktraits::kNRows; constexpr bool kDirectIO = Ktraits::kDirectIO; @@ -161,17 +160,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } } - - // for (int state_idx = threadIdx.x; state_idx < params.dstate; state_idx += blockDim.x) { - // smem_a[state_idx] = A[state_idx * params.A_dstate_stride]; - // smem_bc[state_idx] = B[state_idx * params.B_dstate_stride] * C[state_idx * params.C_dstate_stride]; - // } - - constexpr int kChunkSize = kNThreads * kNItems; - // Use block_size for chunking when APC is enabled, otherwise use 2048 for backwards compatibility - const int iteration_chunk_size = params.cache_enabled ? params.block_size : 2048; - const int n_chunks = (seqlen + iteration_chunk_size - 1) / iteration_chunk_size; + const int block_size = params.cache_enabled ? params.block_size : 2048; const int* batch_cache_indices = cache_indices != nullptr ? cache_indices + batch_id * params.cache_indices_stride : nullptr; @@ -181,10 +171,44 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { reinterpret_cast(params.block_idx_last_scheduled_token_ptr) : nullptr; const int* initial_state_idx = params.initial_state_idx_ptr != nullptr ? reinterpret_cast(params.initial_state_idx_ptr) : nullptr; + const int* cu_chunk_seqlen = params.cu_chunk_seqlen_ptr != nullptr ? + reinterpret_cast(params.cu_chunk_seqlen_ptr) : nullptr; + const int* last_chunk_indices = params.last_chunk_indices_ptr != nullptr ? + reinterpret_cast(params.last_chunk_indices_ptr) : nullptr; const size_t load_cache_slot = params.cache_enabled && batch_cache_indices != nullptr ? batch_cache_indices[initial_state_idx[batch_id]] : cache_index; + const int block_idx_first = (params.cache_enabled && block_idx_first_scheduled != nullptr) ? + block_idx_first_scheduled[batch_id] : 0; + + // Determine chunk boundaries from pre-computed metadata (APC mode) + // or fall back to simple block_size chunking. + int first_chunk_idx, n_chunks; + int current_position; + + if (cu_chunk_seqlen != nullptr && last_chunk_indices != nullptr) { + const int last_chunk_idx = last_chunk_indices[batch_id]; + first_chunk_idx = (batch_id == 0) ? 0 : last_chunk_indices[batch_id - 1] + 1; + n_chunks = last_chunk_idx - first_chunk_idx + 1; + // Derive current_position: if the first chunk is partial (fills remainder + // of a started block), offset into the block accordingly. + const int first_chunk_tokens = cu_chunk_seqlen[first_chunk_idx + 1] - cu_chunk_seqlen[first_chunk_idx]; + const int chunk_start_offset = (n_chunks > 1 && first_chunk_tokens < block_size) + ? (block_size - first_chunk_tokens) : 0; + current_position = block_idx_first * block_size + chunk_start_offset; + } else { + first_chunk_idx = 0; + n_chunks = (seqlen + block_size - 1) / block_size; + current_position = 0; + } + + int tokens_processed = 0; + for (int chunk = 0; chunk < n_chunks; ++chunk) { + const int chunk_tokens = (cu_chunk_seqlen != nullptr) + ? cu_chunk_seqlen[first_chunk_idx + chunk + 1] - cu_chunk_seqlen[first_chunk_idx + chunk] + : min(block_size, seqlen - tokens_processed); + if (chunk_tokens <= 0) break; input_t u_vals[kNRows][kNItems], delta_vals_load[kNRows][kNItems]; __syncthreads(); @@ -193,12 +217,12 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if constexpr (!kDirectIO) { if (r > 0) { __syncthreads(); } } - load_input(u + r * params.u_d_stride, u_vals[r], smem_load, seqlen - chunk * kChunkSize); + load_input(u + r * params.u_d_stride, u_vals[r], smem_load, chunk_tokens); if constexpr (!kDirectIO) { __syncthreads(); } - load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, seqlen - chunk * kChunkSize); + load_input(delta + r * params.delta_d_stride, delta_vals_load[r], smem_load, chunk_tokens); } - u += kChunkSize; - delta += kChunkSize; + u += chunk_tokens; + delta += chunk_tokens; float delta_vals[kNRows][kNItems], delta_u_vals[kNRows][kNItems], out_vals[kNRows][kNItems]; #pragma unroll @@ -232,7 +256,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { weight_t B_vals[kNItems], C_vals[kNItems]; if constexpr (kIsVariableB) { load_weight(Bvar + state_idx * params.B_dstate_stride, B_vals, - smem_load_weight, (seqlen - chunk * kChunkSize) * (1)); + smem_load_weight, chunk_tokens); if constexpr (!kIsVariableC) { #pragma unroll for (int r = 0; r < kNRows; ++r) { @@ -243,7 +267,7 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if constexpr (kIsVariableC) { auto &smem_load_weight_C = !kIsVariableB ? smem_load_weight : smem_load_weight1; load_weight(Cvar + state_idx * params.C_dstate_stride, C_vals, - smem_load_weight_C, (seqlen - chunk * kChunkSize) * (1)); + smem_load_weight_C, chunk_tokens); if constexpr (!kIsVariableB) { #pragma unroll for (int r = 0; r < kNRows; ++r) { @@ -266,10 +290,8 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { for (int i = 0; i < kNItems; ++i) { thread_data[i] = make_float2(exp2f(delta_vals[r][i] * A_val[r]), !kIsVariableB ? delta_u_vals[r][i] : B_vals[i] * delta_u_vals[r][i]); - if (seqlen % (kNItems * kNThreads) != 0) { // So that the last state is correct - if (threadIdx.x * kNItems + i >= seqlen - chunk * kChunkSize) { - thread_data[i] = make_float2(1.f, 0.f); - } + if (threadIdx.x * kNItems + i >= chunk_tokens) { + thread_data[i] = make_float2(1.f, 0.f); } } // Initialize running total @@ -301,14 +323,14 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { if (threadIdx.x == 0) { smem_running_prefix[state_idx + r * MAX_DSTATE] = prefix_op.running_prefix; - // Store state at the end of each chunk when cache is enabled + // Store state at the end of each aligned chunk when cache is enabled if (params.cache_enabled && batch_cache_indices != nullptr) { - size_t cache_slot; if (chunk == n_chunks - 1) { cache_slot = batch_cache_indices[block_idx_last_scheduled[batch_id]]; } else { - cache_slot = batch_cache_indices[block_idx_first_scheduled[batch_id] + chunk]; + const int block_idx_completed = (current_position + chunk_tokens - 1) / block_size; + cache_slot = batch_cache_indices[block_idx_completed]; } size_t state_offset = cache_slot * params.ssm_states_batch_stride + @@ -331,38 +353,41 @@ void selective_scan_fwd_kernel(SSMParamsBase params) { } } input_t *out = reinterpret_cast(params.out_ptr) + sequence_start_index * params.out_batch_stride - + dim_id * kNRows * params.out_d_stride + chunk * kChunkSize; + + dim_id * kNRows * params.out_d_stride + tokens_processed; __syncthreads(); #pragma unroll for (int r = 0; r < kNRows; ++r) { if constexpr (!kDirectIO) { if (r > 0) { __syncthreads(); } } - store_output(out + r * params.out_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize); + store_output(out + r * params.out_d_stride, out_vals[r], smem_store, chunk_tokens); } if constexpr (kHasZ) { input_t *z = reinterpret_cast(params.z_ptr) + sequence_start_index * params.z_batch_stride - + dim_id * kNRows * params.z_d_stride + chunk * kChunkSize; + + dim_id * kNRows * params.z_d_stride + tokens_processed; input_t *out_z = reinterpret_cast(params.out_z_ptr) + sequence_start_index * params.out_z_batch_stride - + dim_id * kNRows * params.out_z_d_stride + chunk * kChunkSize; + + dim_id * kNRows * params.out_z_d_stride + tokens_processed; #pragma unroll for (int r = 0; r < kNRows; ++r) { input_t z_vals[kNItems]; __syncthreads(); - load_input(z + r * params.z_d_stride, z_vals, smem_load, seqlen - chunk * kChunkSize); + load_input(z + r * params.z_d_stride, z_vals, smem_load, chunk_tokens); #pragma unroll for (int i = 0; i < kNItems; ++i) { float z_val = z_vals[i]; out_vals[r][i] *= z_val / (1 + expf(-z_val)); } __syncthreads(); - store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, seqlen - chunk * kChunkSize); + store_output(out_z + r * params.out_z_d_stride, out_vals[r], smem_store, chunk_tokens); } } - Bvar += kChunkSize * 1; - Cvar += kChunkSize * 1; + Bvar += chunk_tokens; + Cvar += chunk_tokens; + + tokens_processed += chunk_tokens; + current_position += chunk_tokens; } } @@ -506,7 +531,9 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, int64_t block_size, const std::optional &block_idx_first_scheduled_token, const std::optional &block_idx_last_scheduled_token, - const std::optional &initial_state_idx) { + const std::optional &initial_state_idx, + const std::optional &cu_chunk_seqlen, + const std::optional &last_chunk_indices) { // Reset the parameters memset(¶ms, 0, sizeof(params)); @@ -548,6 +575,8 @@ void set_ssm_params_fwd(SSMParamsBase ¶ms, params.block_idx_first_scheduled_token_ptr = block_idx_first_scheduled_token.has_value() ? block_idx_first_scheduled_token.value().data_ptr() : nullptr; params.block_idx_last_scheduled_token_ptr = block_idx_last_scheduled_token.has_value() ? block_idx_last_scheduled_token.value().data_ptr() : nullptr; params.initial_state_idx_ptr = initial_state_idx.has_value() ? initial_state_idx.value().data_ptr() : nullptr; + params.cu_chunk_seqlen_ptr = cu_chunk_seqlen.has_value() ? cu_chunk_seqlen.value().data_ptr() : nullptr; + params.last_chunk_indices_ptr = last_chunk_indices.has_value() ? last_chunk_indices.value().data_ptr() : nullptr; // All stride are in elements, not bytes. params.A_d_stride = A.stride(0); @@ -633,7 +662,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, int64_t block_size, const std::optional &block_idx_first_scheduled_token, const std::optional &block_idx_last_scheduled_token, - const std::optional &initial_state_idx) { + const std::optional &initial_state_idx, + const std::optional &cu_chunk_seqlen, + const std::optional &last_chunk_indices) { auto input_type = u.scalar_type(); auto weight_type = A.scalar_type(); TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16); @@ -778,7 +809,9 @@ void selective_scan_fwd(const torch::Tensor &u, const torch::Tensor &delta, block_size, block_idx_first_scheduled_token, block_idx_last_scheduled_token, - initial_state_idx + initial_state_idx, + cu_chunk_seqlen, + last_chunk_indices ); diff --git a/csrc/moe/moe_ops.h b/csrc/moe/moe_ops.h index b71db3569447..d8d962887dab 100644 --- a/csrc/moe/moe_ops.h +++ b/csrc/moe/moe_ops.h @@ -58,6 +58,10 @@ void shuffle_rows(const torch::Tensor& input_tensor, torch::Tensor& output_tensor); #ifndef USE_ROCM +// cuBLAS bf16 x bf16 -> fp32 router GEMM (fallback for non-SM90 / batch > 16) +torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input, + torch::Tensor const& weight); + // DeepSeek V3 optimized router GEMM kernel for SM90+ // Computes output = mat_a @ mat_b.T where: // mat_a: [num_tokens, hidden_dim] in bf16 diff --git a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu b/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu new file mode 100644 index 000000000000..f507f9299b03 --- /dev/null +++ b/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm.cu @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project +// Adapted from SGLang: +// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled.cu + +#include + +#include "cutlass_mxfp8_grouped_mm_launcher.cuh" + +void cutlass_mxfp8_grouped_mm(const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& sfa, + const torch::Tensor& sfb, torch::Tensor& d, + const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, + const torch::Tensor& blockscale_offsets) { +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); + TORCH_CHECK(problem_sizes.size(1) == 3, + "problem_sizes must have shape (num_experts, 3)"); + TORCH_CHECK(problem_sizes.size(0) == expert_offsets.size(0), + "Number of experts in problem_sizes must match expert_offsets"); + TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, + "problem_sizes must be int32"); + TORCH_CHECK(expert_offsets.dtype() == torch::kInt32, + "expert_offsets must be int32"); + TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32, + "blockscale_offsets must be int32"); + TORCH_CHECK(a.dim() == 2, "a must be a 2D tensor of shape (num_tokens, k)"); + TORCH_CHECK(b.dim() == 3, + "b must be a 3D tensor of shape (num_experts, k, n)"); + TORCH_CHECK(a.size(1) == b.size(1) && a.size(1) % 128 == 0, + "k should align 128"); + TORCH_CHECK(b.size(2) % 128 == 0, "n should align 128"); + TORCH_CHECK(a.strides()[1] == 1, "a must be row major"); + TORCH_CHECK(b.strides()[1] == 1, "b must be column major"); + + auto stream = at::cuda::getCurrentCUDAStream(); + if (d.dtype() == torch::kBFloat16) { + expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype< + cutlass::bfloat16_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets, + blockscale_offsets, stream); + } else if (d.dtype() == torch::kFloat16) { + expert_specialization::cutlass_mxfp8_grouped_mm_dispatch_out_dtype< + cutlass::half_t>(a, b, sfa, sfb, d, problem_sizes, expert_offsets, + blockscale_offsets, stream); + } else { + TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16"); + } +#else + TORCH_CHECK(false, + "No implemented cutlass_mxfp8_grouped_mm for " + "current device"); +#endif +} + +#include "core/registration.h" + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("cutlass_mxfp8_grouped_mm", cutlass_mxfp8_grouped_mm); +} \ No newline at end of file diff --git a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh b/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh new file mode 100644 index 000000000000..9fb1dbf8eef5 --- /dev/null +++ b/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_functor.cuh @@ -0,0 +1,141 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project +// Adapted from SGLang: +// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_functor.cuh + +#pragma once +#include + +#include "cute/tensor.hpp" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass_mxfp8_grouped_mm_traits.cuh" + +namespace expert_specialization { + +using namespace cute; + +template +struct CutlassMxfp8GroupedMmOffsetFunctor { + using Gemm = typename GemmTraits::Gemm; + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementSF = typename GemmTraits::ElementSF; + using ElementD = typename GemmTraits::ElementOutput; + // Input + int* expert_offsets{nullptr}; + int* blockscale_offsets{nullptr}; + // Output + ElementA* a_base{nullptr}; + ElementB* b_base{nullptr}; + ElementSF* sfa_base{nullptr}; + ElementSF* sfb_base{nullptr}; + ElementD* d_base{nullptr}; + ElementA** a_offsets{nullptr}; + ElementB** b_offsets{nullptr}; + ElementSF** sfa_offsets{nullptr}; + ElementSF** sfb_offsets{nullptr}; + ElementD** d_offsets{nullptr}; + + CutlassMxfp8GroupedMmOffsetFunctor() = default; + CutlassMxfp8GroupedMmOffsetFunctor( + int* _expert_offsets, int* _blockscale_offsets, ElementA* _a_base, + ElementB* _b_base, ElementSF* _sfa_base, ElementSF* _sfb_base, + ElementD* _d_base, ElementA** _a_offsets, ElementB** _b_offsets, + ElementSF** _sfa_offsets, ElementSF** _sfb_offsets, ElementD** _d_offsets) + : expert_offsets{_expert_offsets}, + blockscale_offsets{_blockscale_offsets}, + a_base(_a_base), + b_base(_b_base), + sfa_base(_sfa_base), + sfb_base(_sfb_base), + d_base(_d_base), + a_offsets(_a_offsets), + b_offsets(_b_offsets), + sfa_offsets(_sfa_offsets), + sfb_offsets(_sfb_offsets), + d_offsets(_d_offsets) {} + + void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { + int64_t expert_offset = static_cast(expert_offsets[expert_id]); + int64_t blockscale_offset = + static_cast(blockscale_offsets[expert_id]); + int64_t a_stride = expert_offset * k; + int64_t b_stride = expert_id * k * n; + int64_t d_stride = expert_offset * n; + int64_t sfa_stride = blockscale_offset * (k / 32); + int64_t sfb_stride = expert_id * n * (k / 32); + + a_offsets[expert_id] = a_base + a_stride; + b_offsets[expert_id] = b_base + b_stride; + sfa_offsets[expert_id] = sfa_base + sfa_stride; + sfb_offsets[expert_id] = sfb_base + sfb_stride; + d_offsets[expert_id] = d_base + d_stride; + } +}; + +template +struct CutlassMxfp8GroupedMmLayoutFunctor { + using Sm1xxBlkScaledConfig = typename GemmTraits::Sm1xxBlkScaledConfig; + using LayoutSFA = typename GemmTraits::LayoutSFA; + using LayoutSFB = typename GemmTraits::LayoutSFB; + LayoutSFA* layout_sfa_base{nullptr}; + LayoutSFB* layout_sfb_base{nullptr}; + + CutlassMxfp8GroupedMmLayoutFunctor() = default; + CutlassMxfp8GroupedMmLayoutFunctor(LayoutSFA* _layout_sfa_base, + LayoutSFB* _layout_sfb_base) + : layout_sfa_base(_layout_sfa_base), layout_sfb_base(_layout_sfb_base) {} + + void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { + LayoutSFA* layout_sfa_ptr = layout_sfa_base + expert_id; + LayoutSFB* layout_sfb_ptr = layout_sfb_base + expert_id; + *layout_sfa_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFA( + cute::make_shape(m, n, k, 1)); + *layout_sfb_ptr = Sm1xxBlkScaledConfig::tile_atom_to_shape_SFB( + cute::make_shape(m, n, k, 1)); + } +}; + +template +struct CutlassMxfp8GroupedMmStrideFunctor { + using StrideA = typename GemmTraits::StrideA; + using StrideB = typename GemmTraits::StrideB; + using StrideD = typename GemmTraits::StrideD; + StrideA* stride_A_base{nullptr}; + StrideB* stride_B_base{nullptr}; + StrideD* stride_D_base{nullptr}; + + CutlassMxfp8GroupedMmStrideFunctor() = default; + CutlassMxfp8GroupedMmStrideFunctor(StrideA* _stride_A_base, + StrideB* _stride_B_base, + StrideD* _stride_D_base) + : stride_A_base(_stride_A_base), + stride_B_base(_stride_B_base), + stride_D_base(_stride_D_base) {} + + void CUTE_DEVICE operator()(int64_t expert_id, int m, int n, int k) { + StrideA* stride_A = stride_A_base + expert_id; + StrideB* stride_B = stride_B_base + expert_id; + StrideD* stride_D = stride_D_base + expert_id; + *stride_A = cutlass::make_cute_packed_stride(StrideA{}, {m, k, 1}); + *stride_B = cutlass::make_cute_packed_stride(StrideB{}, {n, k, 1}); + *stride_D = cutlass::make_cute_packed_stride(StrideD{}, {m, n, 1}); + } +}; + +template +__global__ void cutlassMxfp8GroupedMmPreComputeKernel( + int* problem_sizes, OffsetFunctor offset_functor, + LayoutFunctor layout_functor, StrideFunctor stride_functor) { + int64_t expert_id = static_cast(threadIdx.x); + int m = problem_sizes[expert_id * 3 + 0]; + int n = problem_sizes[expert_id * 3 + 1]; + int k = problem_sizes[expert_id * 3 + 2]; + + offset_functor(expert_id, m, n, k); + layout_functor(expert_id, m, n, k); + stride_functor(expert_id, m, n, k); +} + +} // namespace expert_specialization \ No newline at end of file diff --git a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh b/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh new file mode 100644 index 000000000000..2c46e1fa7252 --- /dev/null +++ b/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_launcher.cuh @@ -0,0 +1,179 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project +// Adapted from SGLang: +// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_launcher.cuh + +#pragma once +#include +#include +#include + +#include +#include +#include + +#include "cute/tensor.hpp" +#include "cutlass_mxfp8_grouped_mm_functor.cuh" +#include "cutlass_mxfp8_grouped_mm_traits.cuh" + +namespace expert_specialization { + +template +void cutlass_mxfp8_grouped_mm_pre_compute( + torch::Tensor& a_ptrs, torch::Tensor& b_ptrs, torch::Tensor& sfa_ptrs, + torch::Tensor& sfb_ptrs, torch::Tensor& d_ptrs, torch::Tensor& stride_a, + torch::Tensor& stride_b, torch::Tensor& stride_d, torch::Tensor& layout_sfa, + torch::Tensor& layout_sfb, const torch::Tensor& a, const torch::Tensor& b, + const torch::Tensor& sfa, const torch::Tensor& sfb, const torch::Tensor& d, + const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets, + const torch::Tensor& blockscale_offsets, cudaStream_t stream) { + using OffsetFunctor = CutlassMxfp8GroupedMmOffsetFunctor; + using ElementA = typename OffsetFunctor::ElementA; + using ElementB = typename OffsetFunctor::ElementB; + using ElementSF = typename OffsetFunctor::ElementSF; + using ElementD = typename OffsetFunctor::ElementD; + + using LayoutFunctor = CutlassMxfp8GroupedMmLayoutFunctor; + using LayoutSFA = typename LayoutFunctor::LayoutSFA; + using LayoutSFB = typename LayoutFunctor::LayoutSFB; + + using StrideFunctor = CutlassMxfp8GroupedMmStrideFunctor; + using StrideA = typename StrideFunctor::StrideA; + using StrideB = typename StrideFunctor::StrideB; + using StrideD = typename StrideFunctor::StrideD; + + int num_experts = (int)expert_offsets.size(0); + TORCH_CHECK(num_experts <= 1024, + "Number of experts cannot exceed 1024, the maximum number of " + "threads per block."); + + OffsetFunctor offset_functor( + reinterpret_cast(expert_offsets.data_ptr()), + reinterpret_cast(blockscale_offsets.data_ptr()), + reinterpret_cast(a.data_ptr()), + reinterpret_cast(b.data_ptr()), + reinterpret_cast(sfa.data_ptr()), + reinterpret_cast(sfb.data_ptr()), + reinterpret_cast(d.data_ptr()), + reinterpret_cast(a_ptrs.data_ptr()), + reinterpret_cast(b_ptrs.data_ptr()), + reinterpret_cast(sfa_ptrs.data_ptr()), + reinterpret_cast(sfb_ptrs.data_ptr()), + reinterpret_cast(d_ptrs.data_ptr())); + LayoutFunctor layout_functor( + reinterpret_cast(layout_sfa.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr())); + StrideFunctor stride_functor(reinterpret_cast(stride_a.data_ptr()), + reinterpret_cast(stride_b.data_ptr()), + reinterpret_cast(stride_d.data_ptr())); + cutlassMxfp8GroupedMmPreComputeKernel<<<1, num_experts, 0, stream>>>( + static_cast(problem_sizes.data_ptr()), offset_functor, + layout_functor, stride_functor); +} + +template +void cutlass_mxfp8_grouped_mm( + const torch::Tensor& a_ptrs, const torch::Tensor& b_ptrs, + const torch::Tensor& sfa_ptrs, const torch::Tensor& sfb_ptrs, + const torch::Tensor& d_ptrs, const torch::Tensor& stride_a, + const torch::Tensor& stride_b, const torch::Tensor& stride_d, + const torch::Tensor& layout_sfa, const torch::Tensor& layout_sfb, + const torch::Tensor& problem_sizes, cudaStream_t stream) { + using Gemm = typename GemmTraits::Gemm; + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementSF = typename GemmTraits::ElementSF; + using ElementD = typename GemmTraits::ElementOutput; + using StrideA = typename GemmTraits::StrideA; + using StrideB = typename GemmTraits::StrideB; + using StrideD = typename GemmTraits::StrideD; + using LayoutSFA = typename GemmTraits::LayoutSFA; + using LayoutSFB = typename GemmTraits::LayoutSFB; + using UnderlyingProblemShape = + typename GemmTraits::ProblemShape::UnderlyingProblemShape; + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = c10::cuda::current_device(); + hw_info.sm_count = + at::cuda::getCurrentDeviceProperties()->multiProcessorCount; + hw_info.cluster_shape = GemmTraits::MMAConfig::preferred_cluster; + hw_info.cluster_shape_fallback = GemmTraits::MMAConfig::fallback_cluster; + + int num_experts = (int)problem_sizes.size(0); + + UnderlyingProblemShape* underlying_problem_shape = + reinterpret_cast(problem_sizes.data_ptr()); + + typename Gemm::Arguments arguments = { + cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, underlying_problem_shape, nullptr}, + {reinterpret_cast(a_ptrs.data_ptr()), + reinterpret_cast(stride_a.data_ptr()), + reinterpret_cast(b_ptrs.data_ptr()), + reinterpret_cast(stride_b.data_ptr()), + reinterpret_cast(sfa_ptrs.data_ptr()), + reinterpret_cast(layout_sfa.data_ptr()), + reinterpret_cast(sfb_ptrs.data_ptr()), + reinterpret_cast(layout_sfb.data_ptr())}, + {{}, + nullptr, + nullptr, + reinterpret_cast(d_ptrs.data_ptr()), + reinterpret_cast(stride_d.data_ptr())}, + hw_info, + {} // Scheduler + }; + + Gemm gemm; + + auto can_implement_status = gemm.can_implement(arguments); + TORCH_CHECK(can_implement_status == cutlass::Status::kSuccess, + "Failed to implement GEMM"); + + torch::TensorOptions options_uint8 = + torch::TensorOptions().dtype(torch::kUInt8).device(d_ptrs.device()); + size_t workspace_size = gemm.get_workspace_size(arguments); + torch::Tensor workspace = torch::empty(workspace_size, options_uint8); + + auto status = gemm.initialize(arguments, workspace.data_ptr(), stream); + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to initialize GEMM"); + + status = gemm.run(stream, nullptr, true); // Enable PDL + TORCH_CHECK(status == cutlass::Status::kSuccess, "Failed to run GEMM"); +} + +template +void cutlass_mxfp8_grouped_mm_dispatch_out_dtype( + const torch::Tensor& a, const torch::Tensor& b, const torch::Tensor& sfa, + const torch::Tensor& sfb, torch::Tensor& d, + const torch::Tensor& problem_sizes, const torch::Tensor& expert_offsets, + const torch::Tensor& blockscale_offsets, cudaStream_t stream) { + int num_experts = (int)problem_sizes.size(0); + torch::TensorOptions options_int64 = + torch::TensorOptions().dtype(torch::kInt64).device(a.device()); + torch::TensorOptions options_int32 = + torch::TensorOptions().dtype(torch::kInt32).device(a.device()); + + torch::Tensor a_ptrs = torch::empty(num_experts, options_int64); + torch::Tensor b_ptrs = torch::empty(num_experts, options_int64); + torch::Tensor sfa_ptrs = torch::empty(num_experts, options_int64); + torch::Tensor sfb_ptrs = torch::empty(num_experts, options_int64); + torch::Tensor d_ptrs = torch::empty(num_experts, options_int64); + + torch::Tensor stride_a = torch::empty(num_experts, options_int64); + torch::Tensor stride_b = torch::empty(num_experts, options_int64); + torch::Tensor stride_d = torch::empty(num_experts, options_int64); + torch::Tensor layout_sfa = torch::empty({num_experts, 5}, options_int32); + torch::Tensor layout_sfb = torch::empty({num_experts, 5}, options_int32); + + using GemmTraits = CutlassMxfp8GroupedMmGemmTraits; + cutlass_mxfp8_grouped_mm_pre_compute( + a_ptrs, b_ptrs, sfa_ptrs, sfb_ptrs, d_ptrs, stride_a, stride_b, stride_d, + layout_sfa, layout_sfb, a, b, sfa, sfb, d, problem_sizes, expert_offsets, + blockscale_offsets, stream); + cutlass_mxfp8_grouped_mm( + a_ptrs, b_ptrs, sfa_ptrs, sfb_ptrs, d_ptrs, stride_a, stride_b, stride_d, + layout_sfa, layout_sfb, problem_sizes, stream); +} + +} // namespace expert_specialization \ No newline at end of file diff --git a/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh b/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh new file mode 100644 index 000000000000..ed8cd7ce0658 --- /dev/null +++ b/csrc/moe/mxfp8_moe/cutlass_mxfp8_grouped_mm_traits.cuh @@ -0,0 +1,127 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project +// Adapted from SGLang: +// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_traits.cuh + +#pragma once + +// Misc +#include "cute/tensor.hpp" +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/detail/sm100_blockscaled_layout.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/layout/layout.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_size.h" + +// Collective Builder +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/gemm/collective/collective_builder.hpp" + +// Integration +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +namespace expert_specialization { + +using namespace cute; + +// Different configs for 1SM and 2SM MMA kernel +struct MMA1SMConfig { + using MmaTileShape = Shape<_128, _128, _128>; + using KernelSchedule = + cutlass::gemm::KernelPtrArrayTmaWarpSpecialized1SmMxf8f6f4Sm100; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecialized1Sm; + const static dim3 preferred_cluster; + const static dim3 fallback_cluster; +}; +const dim3 MMA1SMConfig::preferred_cluster(1, 4, 1); +const dim3 MMA1SMConfig::fallback_cluster(1, 2, 1); + +template +struct CutlassMxfp8GroupedMmGemmTraits { + using MMAConfig = _MMAConfig; + using ElementInput = cutlass::float_e4m3_t; + using ElementOutput = OutputDtype; + using ProblemShape = cutlass::gemm::GroupProblemShape>; + + // A matrix configuration + using ElementA = cutlass::mx_float8_t; + using LayoutA = cutlass::layout::RowMajor; + constexpr static int AlignmentA = 32; + + // B matrix configuration + using ElementB = cutlass::mx_float8_t; + using LayoutB = cutlass::layout::ColumnMajor; + constexpr static int AlignmentB = 32; + + // C/D matrix configuration + using ElementC = void; + using ElementD = ElementOutput; + using LayoutC = cutlass::layout::RowMajor; + using LayoutD = cutlass::layout::RowMajor; + constexpr static int AlignmentC = 128 / cutlass::sizeof_bits::value; + constexpr static int AlignmentD = 128 / cutlass::sizeof_bits::value; + using ElementAccumulator = float; + + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + using CustomEVTIdentity = // acc + cutlass::epilogue::fusion::Sm90EVT< + cutlass::epilogue::fusion::Sm90Compute< + cutlass::epilogue::thread::Identity, ElementD, ElementAccumulator, + RoundStyle>, + cutlass::epilogue::fusion::Sm90AccFetch>; + + // Core kernel configurations + using ArchTag = cutlass::arch::Sm100; + using OperatorClass = cutlass::arch::OpClassBlockScaledTensorOp; + using StageCountType = cutlass::gemm::collective::StageCountAuto; + + // Runtime Cluster Shape + using ClusterShape = Shape; + + // Define Epilogue + using CollectiveEpilogue = + typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, typename MMAConfig::MmaTileShape, + ClusterShape, Shape<_64, _64>, ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, AlignmentC, ElementD, LayoutD*, AlignmentD, + typename MMAConfig::EpilogueSchedule, + CustomEVTIdentity>::CollectiveOp; + + // Define Mainloop + using CollectiveMainloop = + typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, LayoutA*, AlignmentA, ElementB, + LayoutB*, AlignmentB, ElementAccumulator, + typename MMAConfig::MmaTileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + typename MMAConfig::KernelSchedule>::CollectiveOp; + + // Define GemmKernel + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using ElementSF = typename Gemm::GemmKernel::ElementSF; + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + using StrideD = typename Gemm::GemmKernel::InternalStrideD; + using LayoutSFA = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFA; + using LayoutSFB = + typename Gemm::GemmKernel::CollectiveMainloop::InternalLayoutSFB; + using Sm1xxBlkScaledConfig = + typename Gemm::GemmKernel::CollectiveMainloop::Sm1xxBlkScaledConfig; +}; + +} // namespace expert_specialization \ No newline at end of file diff --git a/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu b/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu new file mode 100644 index 000000000000..2a93ab94d5ca --- /dev/null +++ b/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cu @@ -0,0 +1,60 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project +// Adapted from SGLang: +// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cu + +#include + +#include "mxfp8_experts_quant.cuh" + +void mxfp8_experts_quant(const torch::Tensor& input, + const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, + const torch::Tensor& blockscale_offsets, + torch::Tensor& quant_output, + torch::Tensor& scale_factor) { +#if defined(CUTLASS_ARCH_MMA_SM100_SUPPORTED) + TORCH_CHECK(input.dim() == 2, "input must be 2D tensor"); + TORCH_CHECK(input.size(1) % 128 == 0, "k must align to 128"); + TORCH_CHECK(input.strides()[1] == 1, "input must be row major"); + TORCH_CHECK(problem_sizes.dim() == 2, "problem_sizes must be 2D tensor"); + TORCH_CHECK(problem_sizes.dtype() == torch::kInt32, + "problem_sizes must be int32"); + TORCH_CHECK(expert_offsets.dtype() == torch::kInt32, + "expert_offsets must be int32"); + TORCH_CHECK(blockscale_offsets.dtype() == torch::kInt32, + "blockscale_offsets must be int32"); + + auto groups = problem_sizes.size(0); + TORCH_CHECK( + expert_offsets.dim() == 1 && expert_offsets.size(0) == groups, + "expert_offsets must be 1D and have size equal to the number of groups"); + TORCH_CHECK( + blockscale_offsets.dim() == 1 && blockscale_offsets.size(0) == groups, + "blockscale_offsets must be 1D and have size equal to the number of " + "groups"); + + auto stream = at::cuda::getCurrentCUDAStream(); + if (input.dtype() == torch::kBFloat16) { + expert_specialization::launch_mxfp8_experts_quant<__nv_bfloat16>( + input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, + scale_factor); + } else if (input.dtype() == torch::kFloat16) { + expert_specialization::launch_mxfp8_experts_quant<__half>( + input, problem_sizes, expert_offsets, blockscale_offsets, quant_output, + scale_factor); + } else { + TORCH_CHECK(false, "dtype must be kFloat16 or kBFloat16"); + } +#else + TORCH_CHECK(false, + "No implemented mxfp8_experts_quant for " + "current device"); +#endif +} + +#include "core/registration.h" + +TORCH_LIBRARY_IMPL_EXPAND(TORCH_EXTENSION_NAME, CUDA, m) { + m.impl("mxfp8_experts_quant", mxfp8_experts_quant); +} \ No newline at end of file diff --git a/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh b/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh new file mode 100644 index 000000000000..9a85852080fb --- /dev/null +++ b/csrc/moe/mxfp8_moe/mxfp8_experts_quant.cuh @@ -0,0 +1,414 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project +// Adapted from SGLang: +// https://github.com/sgl-project/sglang/blob/ded068a76e00878881d52d5bfb791e0f60d7311b/sgl-kernel/csrc/expert_specialization/es_sm100_mxfp8_blockscaled_group_quant.cuh + +#pragma once +#include +#include +#include +#include +#include +#include + +#include + +#include "cute/tensor.hpp" + +namespace expert_specialization { + +using namespace cute; + +constexpr uint32_t THREAD_BLOCK_SIZE = 128; +constexpr uint32_t WARP_SIZE = 32; +constexpr int BLOCK_M = 128; +constexpr int BLOCK_K = 128; +using ThrLayout = Layout, Stride<_8, _1>>; +using ValLayout = Layout>; +using SfR2SThrLayout = Layout, Stride<_4, _1>>; +using SfR2SValLayout = Layout>; +using ScaleFactorTileLayout = + Layout, _4>, Stride, _1>>; + +// Fast reciprocal. +inline __device__ float reciprocal_approximate_ftz(float a) { + float b; + asm volatile("rcp.approx.ftz.f32 %0, %1;\n" : "=f"(b) : "f"(a)); + return b; +} + +// Some code references TRT-LLM: +// https://github.com/NVIDIA/TensorRT-LLM/blob/main/cpp/tensorrt_llm/kernels/quantization.cuh +template +__inline__ __device__ uint8_t cvt_warp_fp16_to_mxfp8(FragmentS& fragment_s, + FragmentD& fragment_d) { + using FragmentSLayout = typename FragmentS::layout_type; + using FragmentDLayout = typename FragmentD::layout_type; + FragmentSLayout fragment_s_layout; + FragmentDLayout fragment_d_layout; + static_assert(is_static::value && + size(fragment_s_layout) == 16); + static_assert(is_static::value && + size(fragment_d_layout) == 16); + + constexpr int eles_per_thr = 16; + using ValType = typename FragmentS::element_type; + using VecType = std::conditional_t, + __nv_bfloat162, __half2>; + VecType vec[8]; + // Assign vals + vec[0].x = fragment_s(Int<0>{}); + vec[0].y = fragment_s(Int<1>{}); + vec[1].x = fragment_s(Int<2>{}); + vec[1].y = fragment_s(Int<3>{}); + vec[2].x = fragment_s(Int<4>{}); + vec[2].y = fragment_s(Int<5>{}); + vec[3].x = fragment_s(Int<6>{}); + vec[3].y = fragment_s(Int<7>{}); + vec[4].x = fragment_s(Int<8>{}); + vec[4].y = fragment_s(Int<9>{}); + vec[5].x = fragment_s(Int<10>{}); + vec[5].y = fragment_s(Int<11>{}); + vec[6].x = fragment_s(Int<12>{}); + vec[6].y = fragment_s(Int<13>{}); + vec[7].x = fragment_s(Int<14>{}); + vec[7].y = fragment_s(Int<15>{}); + + auto local_max = __habs2(vec[0]); + for (int i = 1; i < eles_per_thr / 2; i++) { + local_max = __hmax2(__habs2(vec[i]), local_max); + } + local_max = __hmax2(__shfl_xor_sync(uint32_t(-1), local_max, 1), local_max); + + // Get the final absolute maximum values. + float block_max(0.0f); + if constexpr (std::is_same_v) { + block_max = __bfloat162float(__hmax(local_max.x, local_max.y)); + } else { + block_max = __half2float(__hmax(local_max.x, local_max.y)); + } + // Get the SF (max value of the vector / max value of mxfp8). + float sf_val = block_max * reciprocal_approximate_ftz(448.0f); + // 8 bits representation of the SF. + uint8_t fp8_sf_val; + + __nv_fp8_e8m0 tmp_sf_val; + tmp_sf_val.__x = + __nv_cvt_float_to_e8m0(sf_val, __NV_SATFINITE, cudaRoundPosInf); + sf_val = static_cast(tmp_sf_val); + fp8_sf_val = tmp_sf_val.__x; + // Get the output scale (reciprocal of the SFValue). + float output_scale = + block_max != 0.f ? reciprocal_approximate_ftz(sf_val) : 0.0f; + + // Convert the input to float. + float2 fp2_vals[eles_per_thr / 2]; + +#pragma unroll + for (int i = 0; i < eles_per_thr / 2; i++) { + if constexpr (std::is_same_v) { + fp2_vals[i] = __half22float2(vec[i]); + } else { + fp2_vals[i] = __bfloat1622float2(vec[i]); + } + fp2_vals[i].x *= output_scale; + fp2_vals[i].y *= output_scale; + } + union { + uint8_t bytes[16]; + __nv_fp8x2_e4m3 elts[8]; + } u; + u.elts[0] = __nv_fp8x2_e4m3(fp2_vals[0]); + u.elts[1] = __nv_fp8x2_e4m3(fp2_vals[1]); + u.elts[2] = __nv_fp8x2_e4m3(fp2_vals[2]); + u.elts[3] = __nv_fp8x2_e4m3(fp2_vals[3]); + u.elts[4] = __nv_fp8x2_e4m3(fp2_vals[4]); + u.elts[5] = __nv_fp8x2_e4m3(fp2_vals[5]); + u.elts[6] = __nv_fp8x2_e4m3(fp2_vals[6]); + u.elts[7] = __nv_fp8x2_e4m3(fp2_vals[7]); + fragment_d(Int<0>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[0]); + fragment_d(Int<1>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[1]); + fragment_d(Int<2>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[2]); + fragment_d(Int<3>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[3]); + fragment_d(Int<4>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[4]); + fragment_d(Int<5>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[5]); + fragment_d(Int<6>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[6]); + fragment_d(Int<7>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[7]); + fragment_d(Int<8>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[8]); + fragment_d(Int<9>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[9]); + fragment_d(Int<10>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[10]); + fragment_d(Int<11>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[11]); + fragment_d(Int<12>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[12]); + fragment_d(Int<13>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[13]); + fragment_d(Int<14>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[14]); + fragment_d(Int<15>{}) = cutlass::float_e4m3_t::bitcast(u.bytes[15]); + return fp8_sf_val; +} + +template +__inline__ __device__ void mxfp8_experts_quant_tile( + TensorS& tensor_s, TensorP& tensor_p, TensorD& tensor_d, + TensorSharedSF& tensor_shared_sf, TensorSF& tensor_sf, int m, + TiledCopyG2R& tiled_copy_g2r, TiledCopyR2G& tiled_copy_r2g, + TiledCopyR2S& tiled_copy_r2s) { + static_assert(size(get<0>(typename TensorS::layout_type{})) == 128 && + size(get<1>(typename TensorS::layout_type{})) == 128 && + stride(get<1>(typename TensorS::layout_type{})) == 1); + static_assert(size(get<0>(typename TensorD::layout_type{})) == 128 && + size(get<1>(typename TensorD::layout_type{})) == 128 && + stride(get<1>(typename TensorD::layout_type{})) == 1); + static_assert(size(get<0>(typename TensorP::layout_type{})) == 128 && + size(get<1>(typename TensorP::layout_type{})) == 128); + static_assert(size(get<0>(typename TensorSharedSF::layout_type{})) == 128 && + size(get<1>(typename TensorSharedSF::layout_type{})) == 4); + static_assert(size(get<0>(typename TensorSF::layout_type{})) == 128 && + size(get<1>(typename TensorSF::layout_type{})) == 4); + + using Tiler_MN = typename TiledCopyG2R::Tiler_MN; + auto tiler_mn = Tiler_MN{}; + static_assert(size<0>(tiler_mn) == 16 && size<1>(tiler_mn) == 128); + + auto tiled_tensor_s = tiled_divide(tensor_s, tiler_mn); + auto tiled_tensor_p = tiled_divide(tensor_p, tiler_mn); + auto tiled_tensor_d = tiled_divide(tensor_d, tiler_mn); + static_assert(size<2>(tiled_tensor_s) == 1); + static_assert(size<2>(tiled_tensor_p) == 1); + static_assert(size<2>(tiled_tensor_d) == 1); + auto squeeze_tiled_tensor_s = take<0, 2>(tiled_tensor_s); + auto squeeze_tiled_tensor_p = take<0, 2>(tiled_tensor_p); + auto squeeze_tiled_tensor_d = take<0, 2>(tiled_tensor_d); + + using SF_Tiler_MN = typename TiledCopyR2S::Tiler_MN; + auto sf_tiler_mn = SF_Tiler_MN{}; + static_assert(size<0>(sf_tiler_mn) == 16 && size<1>(sf_tiler_mn) == 4); + + auto tiled_tensor_sf = tiled_divide(tensor_sf, sf_tiler_mn); + auto tiled_tensor_shared_sf = tiled_divide(tensor_shared_sf, sf_tiler_mn); + auto squeeze_tiled_tensor_sf = take<0, 2>(tiled_tensor_sf); + auto squeeze_tiled_tensor_shared_sf = take<0, 2>(tiled_tensor_shared_sf); + + constexpr int tile_loop_count = size<1>(tiled_tensor_s); + constexpr int rows_in_tile = 16; + // We don't need to clear shared memory + // clear(squeeze_tiled_tensor_shared_sf); +#pragma unroll 4 + for (int t = 0; t < tile_loop_count; t++) { + if (t * rows_in_tile >= m) { + break; + } + auto current_copy_tile_s = tensor<0>(squeeze_tiled_tensor_s(_, t)); + auto current_copy_tile_p = tensor<0>(squeeze_tiled_tensor_p(_, t)); + auto current_copy_tile_d = tensor<0>(squeeze_tiled_tensor_d(_, t)); + auto current_copy_tile_sf = tensor<0>(squeeze_tiled_tensor_sf(_, t)); + auto current_copy_tile_shared_sf = + tensor<0>(squeeze_tiled_tensor_shared_sf(_, t)); + + // Global to Register copy + auto thr_copy_g2r = tiled_copy_g2r.get_thread_slice(threadIdx.x); + auto thr_tile_g2r_s = thr_copy_g2r.partition_S(current_copy_tile_s); + auto thr_tile_g2r_p = thr_copy_g2r.partition_S(current_copy_tile_p); + auto input_fragment = make_fragment_like(thr_tile_g2r_s); + + // Register to Global copy + auto thr_copy_r2g = tiled_copy_r2g.get_thread_slice(threadIdx.x); + auto thr_tile_r2g_d = thr_copy_r2g.partition_D(current_copy_tile_d); + auto thr_tile_r2g_p = thr_copy_r2g.partition_D(current_copy_tile_p); + auto output_fragment = make_fragment_like(thr_tile_r2g_d); + + // Register to Shared copy + auto thr_copy_r2s = tiled_copy_r2s.get_thread_slice(threadIdx.x / 2); + auto thr_tile_r2s_shared_sf = + thr_copy_r2s.partition_D(current_copy_tile_shared_sf); + auto shared_sf_fragment = make_fragment_like(thr_tile_r2s_shared_sf); + + // CopyG2R & convert & CopyR2G + copy_if(tiled_copy_g2r, thr_tile_g2r_p, thr_tile_g2r_s, input_fragment); + uint8_t fp8_sf_val = + cvt_warp_fp16_to_mxfp8(input_fragment, output_fragment); + copy_if(tiled_copy_r2g, thr_tile_r2g_p, output_fragment, thr_tile_r2g_d); + shared_sf_fragment[0] = fp8_sf_val; + + // Before first copy r2s, clear shared memory and wait previous group + if (t == 0 && threadIdx.x == 0) { + // Wait for the group to have completed reading from shared memory. + cuda::ptx::cp_async_bulk_wait_group_read(cuda::ptx::n32_t<0>()); + } + __syncthreads(); + + if (threadIdx.x % 2 == 0) { + copy(tiled_copy_r2s, shared_sf_fragment, thr_tile_r2s_shared_sf); + } + __syncthreads(); + } + + // Wait for shared memory writes to be visible to TMA engine. + cuda::ptx::fence_proxy_async(cuda::ptx::space_shared); // b) + __syncthreads(); + + if (threadIdx.x == 0) { + cuda::ptx::cp_async_bulk(cuda::ptx::space_global, cuda::ptx::space_shared, + squeeze_tiled_tensor_sf.data().get(), + squeeze_tiled_tensor_shared_sf.data().get(), 512); + // Wait for TMA transfer to have finished reading shared memory. + // Create a "bulk async-group" out of the previous bulk copy operation. + cuda::ptx::cp_async_bulk_commit_group(); + } + __syncthreads(); +} + +template +__global__ void mxfp8_experts_quant_kernel( + const T_IN* input, const int* problem_sizes, const int* expert_offsets, + const int* blockscale_offsets, cutlass::float_e4m3_t* quant_output, + uint8_t* scale_factor, int groups, TiledCopyG2R tiled_copy_g2r, + TiledCopyR2G tiled_copy_r2g, TiledCopyR2S tiled_copy_r2s) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 1000 + __shared__ __align__(512) uint8_t shared_memory[512]; + ScaleFactorTileLayout scale_factor_tile_layout{}; + auto scale_factor_shared = + make_tensor(make_smem_ptr(shared_memory), + scale_factor_tile_layout); // ((_32,_4), _4):((_16,_4), _1) + // TODO: Transform Groupwise Schedule into a more efficient Schedule + for (int g = 0; g < groups; g++) { + int m = problem_sizes[g * 3 + 0]; + int k = problem_sizes[g * 3 + 2]; + int64_t expert_offset = static_cast(expert_offsets[g]); + int64_t blockscale_offset = static_cast(blockscale_offsets[g]); + + auto input_tensor = make_tensor( + make_gmem_ptr(input + expert_offset * k), + make_layout(make_shape(m, k), + LayoutRight{})); // (M, K):(K, 1) half_t/bfloat16_t + + auto quant_output_tensor = make_tensor( + make_gmem_ptr(quant_output + expert_offset * k), + make_layout(make_shape(m, k), + LayoutRight{})); // (M, K):(K, 1) cutlass::float_e4m3_t + + auto scale_factor_shape = make_shape(ceil_div(m, 128) * 128, k / 32); + auto scale_factor_layout = tile_to_shape(scale_factor_tile_layout, + scale_factor_shape, LayoutRight{}); + // layout<0>(layout<0>(scale_factor_layout)) (_32,_4):(_16,_4) -- static + // layout<1>(layout<0>(scale_factor_layout)) M_align_128 / 128 -- dynamic + // shape dynamic stride layout<0>(layout<1>(scale_factor_layout)) _4:_1 -- + // static layout<1>(layout<1>(scale_factor_layout)) (K / 32) / 4 : _512 -- + // dynamic shape static stride + + // Reshape to zipped layout for 1D indexing + auto zipped_scale_factor_layout = make_layout( + make_layout(layout<0>(layout<0>(scale_factor_layout)), + layout<0>(layout<1>(scale_factor_layout))), + make_layout( + layout<1>(layout<0>(scale_factor_layout)), + layout<1>(layout<1>( + scale_factor_layout)))); // (((_32,_4),_4),(M_align_128 / + // 128,(K / 32) / + // 4)):(((_16,_4),_1),(?,_512)) + + auto scale_factor_tensor = + make_tensor(make_gmem_ptr(scale_factor + blockscale_offset * (k / 32)), + zipped_scale_factor_layout); + + // Used for cases where M is not divisible by 128 (most scenarios). + auto input_shape = shape(input_tensor); // (M, K):(K, 1) + auto identity_tensor = make_identity_tensor(input_shape); + auto predict_tensor = cute::lazy::transform( + identity_tensor, [&](auto c) { return elem_less(c, input_shape); }); + + // (_128, _128) + auto tiler = make_shape(Int{}, Int{}); + + auto tiled_input_tensor = zipped_divide( + input_tensor, tiler); // ((128, 128), (cdiv(M, 128), cdiv(K, 128))) + auto tiled_quant_output_tensor = + zipped_divide(quant_output_tensor, + tiler); // ((128, 128), (cdiv(M, 128), cdiv(K, 128))) + auto tiled_predict_tensor = zipped_divide( + predict_tensor, tiler); // ((128, 128), (cdiv(M, 128), cdiv(K, 128))) + + auto total_tiles = + size<1>(tiled_input_tensor); // cdiv(M, 128) * cdiv(K, 128) + decltype(total_tiles) blk_offset = blockIdx.x; + while (blk_offset < total_tiles) { + auto current_input_tile = tensor<0>(tiled_input_tensor(_, blk_offset)); + auto current_quant_output_tile = + tensor<0>(tiled_quant_output_tensor(_, blk_offset)); + auto current_predict_tile = + tensor<0>(tiled_predict_tensor(_, blk_offset)); + auto current_scale_factor_tile = + tensor<0>(scale_factor_tensor(_, blk_offset)); + + mxfp8_experts_quant_tile< + decltype(current_input_tile), decltype(current_predict_tile), + decltype(current_quant_output_tile), decltype(scale_factor_shared), + decltype(current_scale_factor_tile), TiledCopyG2R, TiledCopyR2G, + TiledCopyR2S>(current_input_tile, current_predict_tile, + current_quant_output_tile, scale_factor_shared, + current_scale_factor_tile, m, tiled_copy_g2r, + tiled_copy_r2g, tiled_copy_r2s); + blk_offset += gridDim.x; + } + } +#endif +} + +template +void launch_mxfp8_experts_quant(const torch::Tensor& input, + const torch::Tensor& problem_sizes, + const torch::Tensor& expert_offsets, + const torch::Tensor& blockscale_offsets, + torch::Tensor& quant_output, + torch::Tensor& scale_factor) { + ThrLayout thr_layout{}; + ValLayout val_layout{}; + SfR2SThrLayout r2s_thr_layout{}; + SfR2SValLayout r2s_val_layout{}; + + using CopyOpG2R = + UniversalCopy>; + using CopyAtomG2R = cute::Copy_Atom; + auto tiled_copy_g2r = cute::make_tiled_copy( + CopyAtomG2R{}, thr_layout, val_layout); // Tiler_MN: (16, 128) + + using CopyOpR2G = UniversalCopy< + cutlass::AlignedArray>; + using CopyAtomR2G = cute::Copy_Atom; + auto tiled_copy_r2g = cute::make_tiled_copy( + CopyAtomR2G{}, thr_layout, val_layout); // Tiler_MN: (16, 128) + + using CopyOpR2S = + UniversalCopy>; + using CopyAtomR2S = cute::Copy_Atom; + auto tiled_copy_r2s = cute::make_tiled_copy( + CopyAtomR2S{}, r2s_thr_layout, r2s_val_layout); // Tiler_MN: (16, 4) + + int max_active_blocks_per_sm = -1; + AT_CUDA_CHECK(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks_per_sm, + mxfp8_experts_quant_kernel, + THREAD_BLOCK_SIZE, 0)); + + dim3 grid(at::cuda::getCurrentDeviceProperties()->multiProcessorCount * + max_active_blocks_per_sm, + 1, 1); + dim3 block(THREAD_BLOCK_SIZE, 1, 1); + int num_experts = (int)problem_sizes.size(0); + auto stream = at::cuda::getCurrentCUDAStream(); + mxfp8_experts_quant_kernel + <<>>( + reinterpret_cast(input.data_ptr()), + reinterpret_cast(problem_sizes.data_ptr()), + reinterpret_cast(expert_offsets.data_ptr()), + reinterpret_cast(blockscale_offsets.data_ptr()), + reinterpret_cast(quant_output.data_ptr()), + reinterpret_cast(scale_factor.data_ptr()), num_experts, + tiled_copy_g2r, tiled_copy_r2g, tiled_copy_r2s); +} + +} // namespace expert_specialization \ No newline at end of file diff --git a/csrc/moe/router_gemm.cu b/csrc/moe/router_gemm.cu new file mode 100644 index 000000000000..a939f8846ff1 --- /dev/null +++ b/csrc/moe/router_gemm.cu @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright contributors to the vLLM project + +// bf16 x bf16 -> fp32 router GEMM via cuBLAS. +// Uses CUBLAS_COMPUTE_32F so bf16 operands accumulate into fp32, +// matching TRT-LLM's cuBLAS fallback behaviour in dsv3RouterGemmOp. + +#include +#include +#include + +// cuBLAS column-major math for row-major PyTorch tensors: +// weight[N,K]_row lda=K -> cuBLAS sees (K,N) col-major; CUBLAS_OP_T -> +// (N,K) input[M,K]_row ldb=K -> cuBLAS sees (K,M) col-major; CUBLAS_OP_N +// -> (K,M) out[M,N]_row ldc=N -> cuBLAS sees (N,M) col-major (written as +// output^T) +// cuBLAS: C(N,M) = weight(N,K) @ input(K,M) => C^T = output[M,N] +// params: m=N, n=M, k=K, lda=K (weight), ldb=K (input), ldc=N (output) + +torch::Tensor router_gemm_bf16_fp32(torch::Tensor const& input, + torch::Tensor const& weight) { + TORCH_CHECK(input.dtype() == torch::kBFloat16, + "router_gemm_bf16_fp32: input must be bfloat16"); + TORCH_CHECK(weight.dtype() == torch::kBFloat16, + "router_gemm_bf16_fp32: weight must be bfloat16"); + TORCH_CHECK(input.dim() == 2 && weight.dim() == 2, + "router_gemm_bf16_fp32: input and weight must be 2-D"); + TORCH_CHECK(input.size(1) == weight.size(1), + "router_gemm_bf16_fp32: inner dimensions must match"); + + int64_t const M = input.size(0); + int64_t const N = weight.size(0); + int64_t const K = input.size(1); + + auto out = torch::empty({M, N}, input.options().dtype(torch::kFloat32)); + + cublasHandle_t handle = at::cuda::getCurrentCUDABlasHandle(); + TORCH_CUDABLAS_CHECK( + cublasSetStream(handle, at::cuda::getCurrentCUDAStream())); + + float const alpha = 1.0f; + float const beta = 0.0f; + + TORCH_CUDABLAS_CHECK(cublasGemmEx( + handle, CUBLAS_OP_T, CUBLAS_OP_N, static_cast(N), + static_cast(M), static_cast(K), &alpha, weight.data_ptr(), + CUDA_R_16BF, static_cast(K), input.data_ptr(), CUDA_R_16BF, + static_cast(K), &beta, out.data_ptr(), CUDA_R_32F, + static_cast(N), CUBLAS_COMPUTE_32F, CUBLAS_GEMM_DEFAULT)); + + return out; +} diff --git a/csrc/moe/torch_bindings.cpp b/csrc/moe/torch_bindings.cpp index 438599451452..7b627a6f8760 100644 --- a/csrc/moe/torch_bindings.cpp +++ b/csrc/moe/torch_bindings.cpp @@ -125,6 +125,10 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, m) { "Tensor)"); m.impl("grouped_topk", torch::kCUDA, &grouped_topk); + // cuBLAS bf16 x bf16 -> fp32 router GEMM (fallback for non-SM90 / batch > 16) + m.def("router_gemm_bf16_fp32(Tensor input, Tensor weight) -> Tensor"); + m.impl("router_gemm_bf16_fp32", torch::kCUDA, &router_gemm_bf16_fp32); + // DeepSeek V3 optimized router GEMM for SM90+ m.def("dsv3_router_gemm(Tensor! output, Tensor mat_a, Tensor mat_b) -> ()"); // conditionally compiled so impl registration is in source file diff --git a/csrc/ops.h b/csrc/ops.h index 5e2b475fa8c1..921d6484d2d3 100644 --- a/csrc/ops.h +++ b/csrc/ops.h @@ -269,13 +269,13 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets( torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, const int64_t n, const int64_t k, const bool swap_ab); -void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, - torch::Tensor& problem_sizes2, - const torch::Tensor& expert_num_tokens, - const int64_t num_local_experts, - const int64_t padded_m, const int64_t n, - const int64_t k); +void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + const torch::Tensor& expert_num_tokens, + const int64_t num_local_experts, + const int64_t padded_m, const int64_t n, + const int64_t k); void cutlass_scaled_mm_azp(torch::Tensor& out, torch::Tensor const& a, torch::Tensor const& b, @@ -371,7 +371,9 @@ void selective_scan_fwd( const torch::Tensor& ssm_states, int64_t pad_slot_id, int64_t block_size, const std::optional& block_idx_first_scheduled_token, const std::optional& block_idx_last_scheduled_token, - const std::optional& initial_state_idx); + const std::optional& initial_state_idx, + const std::optional& cu_chunk_seqlen, + const std::optional& last_chunk_indices); torch::Tensor dynamic_4bit_int_moe_cpu( torch::Tensor x, torch::Tensor topk_ids, torch::Tensor topk_weights, diff --git a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu index d0264c4d154c..3539096c9feb 100644 --- a/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu +++ b/csrc/quantization/fp4/activation_nvfp4_quant_fusion_kernels.cu @@ -39,12 +39,12 @@ namespace vllm { template __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) silu_mul_cvt_fp16_to_fp4(int32_t numRows, int32_t numCols, - int32_t num_padded_cols, + int32_t num_packed_cols, Type const* __restrict__ in, float const* __restrict__ SFScale, uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) { - using PackedVec = vllm::PackedVec; + using PackedVec = vllm::PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, @@ -63,7 +63,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) // Input tensor row/col loops. for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { - if (colIdx < num_padded_cols) { + if (colIdx < num_packed_cols) { PackedVec in_vec; PackedVec in_vec2; int64_t inOffset = @@ -73,19 +73,19 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) bool valid = (rowIdx < numRows) && (elem_idx < numCols); if constexpr (CVT_FP4_PACK16) { - ld256_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 8], - valid); - ld256_or_zero_cg_u32( - in_vec2, &reinterpret_cast(in)[inOffset2 * 8], - valid); + ld256_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 8], + valid); + ld256_cg_or_zero(reinterpret_cast(in_vec2), + &reinterpret_cast(in)[inOffset2 * 8], + valid); } else { - ld128_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 4], - valid); - ld128_or_zero_cg_u32( - in_vec2, &reinterpret_cast(in)[inOffset2 * 4], - valid); + ld128_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 4], + valid); + ld128_cg_or_zero(reinterpret_cast(in_vec2), + &reinterpret_cast(in)[inOffset2 * 4], + valid); } // Compute silu and mul @@ -107,7 +107,9 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) (uint64_t(out_val.hi) << 32) | uint64_t(out_val.lo); reinterpret_cast(out)[outOffset >> 1] = packed64; } else { - out[inOffset] = out_val; + int64_t outOffset = + rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; + out[outOffset] = out_val; } } } @@ -140,9 +142,9 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] int const numBlocksPerSM = vllm_runtime_blocks_per_sm(static_cast(block.x)); - int sf_n_unpadded = int(n / CVT_FP4_SF_VEC_SIZE); + int num_packed_cols = int(n / CVT_FP4_ELTS_PER_THREAD); - int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); + int grid_y = vllm::div_round_up(num_packed_cols, static_cast(block.x)); int grid_x = std::min( int(m), std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); dim3 grid(grid_x, grid_y); @@ -152,7 +154,7 @@ void silu_and_mul_nvfp4_quant_sm1xxa(torch::Tensor& output, // [..., d] using cuda_type = vllm::CUDATypeConverter::Type; auto input_ptr = static_cast(input.data_ptr()); vllm::silu_mul_cvt_fp16_to_fp4<<>>( - m, n, sf_n_unpadded, input_ptr, input_sf_ptr, + m, n, num_packed_cols, input_ptr, input_sf_ptr, reinterpret_cast(output_ptr), reinterpret_cast(sf_out)); }); diff --git a/csrc/quantization/fp4/nvfp4_experts_quant.cu b/csrc/quantization/fp4/nvfp4_experts_quant.cu index 32685c201102..3162b6cdb8a9 100644 --- a/csrc/quantization/fp4/nvfp4_experts_quant.cu +++ b/csrc/quantization/fp4/nvfp4_experts_quant.cu @@ -43,7 +43,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) uint32_t* input_offset_by_experts, uint32_t* output_scale_offset_by_experts, int n_experts, bool low_latency) { - using PackedVec = PackedVec; + using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, @@ -155,7 +155,7 @@ __global__ void __launch_bounds__(1024, VLLM_BLOCKS_PER_SM(1024)) float const* SFScale, uint32_t* out, uint32_t* SFout, uint32_t* input_offset_by_experts, uint32_t* output_scale_offset_by_experts, int n_experts) { - using PackedVec = PackedVec; + using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); static_assert(sizeof(PackedVec) == sizeof(Type) * CVT_FP4_ELTS_PER_THREAD, diff --git a/csrc/quantization/fp4/nvfp4_quant_kernels.cu b/csrc/quantization/fp4/nvfp4_quant_kernels.cu index c27fb69d44be..773047c22500 100644 --- a/csrc/quantization/fp4/nvfp4_quant_kernels.cu +++ b/csrc/quantization/fp4/nvfp4_quant_kernels.cu @@ -42,7 +42,7 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) Type const* __restrict__ in, float const* __restrict__ SFScale, uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) { - using PackedVec = vllm::PackedVec; + using PackedVec = vllm::PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -71,13 +71,13 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) // If we are outside valid rows OR outside valid columns -> Use Zeros bool valid = (rowIdx < numRows) && (elem_idx < numCols); if constexpr (CVT_FP4_PACK16) { - ld256_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 8], - valid); + ld256_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 8], + valid); } else { - ld128_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 4], - valid); + ld128_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 4], + valid); } auto sf_out = @@ -109,11 +109,12 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) template __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) cvt_fp16_to_fp4_sf_major(int32_t numRows, int32_t numCols, - int32_t sf_n_unpadded, Type const* __restrict__ in, + int32_t sf_n_unpadded, int32_t num_packed_cols, + Type const* __restrict__ in, float const* __restrict__ SFScale, uint32_t* __restrict__ out, uint32_t* __restrict__ SFout) { - using PackedVec = PackedVec; + using PackedVec = PackedVec; static constexpr int CVT_FP4_NUM_THREADS_PER_SF = (CVT_FP4_SF_VEC_SIZE / CVT_FP4_ELTS_PER_THREAD); @@ -131,20 +132,20 @@ __global__ void __launch_bounds__(512, VLLM_BLOCKS_PER_SM(512)) // Iterate over all rows and cols including padded ones - // ensures we visit every single scale factor address to initialize it. for (int rowIdx = blockIdx.x; rowIdx < numRows; rowIdx += gridDim.x) { - if (colIdx < sf_n_unpadded) { + if (colIdx < num_packed_cols) { PackedVec in_vec; int64_t inOffset = rowIdx * (numCols / CVT_FP4_ELTS_PER_THREAD) + colIdx; // If we are outside valid rows OR outside valid columns -> Use Zeros bool valid = (rowIdx < numRows) && (elem_idx < numCols); if constexpr (CVT_FP4_PACK16) { - ld256_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 8], - valid); + ld256_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 8], + valid); } else { - ld128_or_zero_cg_u32( - in_vec, &reinterpret_cast(in)[inOffset * 4], - valid); + ld128_cg_or_zero(reinterpret_cast(in_vec), + &reinterpret_cast(in)[inOffset * 4], + valid); } auto sf_out = @@ -222,7 +223,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, reinterpret_cast(sf_out)); }); } else { - int grid_y = vllm::div_round_up(sf_n_unpadded, static_cast(block.x)); + int num_packed_cols = n / CVT_FP4_ELTS_PER_THREAD; + int grid_y = vllm::div_round_up(num_packed_cols, static_cast(block.x)); int grid_x = std::min( m, std::max(1, (multiProcessorCount * numBlocksPerSM) / grid_y)); dim3 grid(grid_x, grid_y); @@ -232,8 +234,8 @@ void scaled_fp4_quant_sm1xxa(torch::Tensor const& output, auto input_ptr = static_cast(input.data_ptr()); // NOTE: We don't support e8m0 scales at this moment. vllm::cvt_fp16_to_fp4_sf_major - <<>>(m, n, sf_n_unpadded, input_ptr, - input_sf_ptr, + <<>>(m, n, sf_n_unpadded, num_packed_cols, + input_ptr, input_sf_ptr, reinterpret_cast(output_ptr), reinterpret_cast(sf_out)); }); diff --git a/csrc/quantization/fp4/nvfp4_utils.cuh b/csrc/quantization/fp4/nvfp4_utils.cuh index 3e7adb9e2931..c1df1860c1a1 100644 --- a/csrc/quantization/fp4/nvfp4_utils.cuh +++ b/csrc/quantization/fp4/nvfp4_utils.cuh @@ -19,8 +19,10 @@ #include #include -#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \ - defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) +#include "../../cuda_vec_utils.cuh" + +#if defined(NVFP4_ENABLE_ELTS16) && defined(CUDA_VERSION) && \ + CUDA_VERSION >= 12090 #define ELTS_PER_THREAD 16 constexpr int CVT_FP4_ELTS_PER_THREAD = 16; constexpr bool CVT_FP4_PACK16 = true; @@ -34,68 +36,6 @@ constexpr int CVT_FP4_SF_VEC_SIZE = 16; namespace vllm { -// Convert PyTorch cpp type to CUDA type -template -struct CUDATypeConverter { - using Type = T; -}; - -template <> -struct CUDATypeConverter { - using Type = half; -}; - -template <> -struct CUDATypeConverter { - using Type = __nv_bfloat16; -}; - -// Get type2 from type or vice versa (applied to half and bfloat16) -template -struct TypeConverter { - using Type = half2; -}; // keep for generality - -template <> -struct TypeConverter { - using Type = half; -}; - -template <> -struct TypeConverter { - using Type = half2; -}; - -template <> -struct TypeConverter<__nv_bfloat162> { - using Type = __nv_bfloat16; -}; - -template <> -struct TypeConverter<__nv_bfloat16> { - using Type = __nv_bfloat162; -}; - -#if (defined(NVFP4_ENABLE_ELTS16) && (CUDART_VERSION >= 12090) && \ - defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) -// Define a 32 bytes packed data type. -template -struct alignas(32) PackedVec { - typename TypeConverter::Type elts[8]; -}; -#else -// Define a 16 bytes packed data type. -template -struct alignas(16) PackedVec { - typename TypeConverter::Type elts[4]; -}; -#endif - -template <> -struct PackedVec<__nv_fp8_e4m3> { - __nv_fp8x2_e4m3 elts[8]; -}; - template __host__ __device__ inline Int round_up(Int x, Int y) { static_assert(std::is_integral_v, @@ -208,56 +148,6 @@ __device__ __forceinline__ float reciprocal_approximate_ftz(float a) { return b; } -template -__device__ __forceinline__ void ld128_or_zero_cg_u32(PackedVec& out, - const void* ptr, - bool pred) { - uint32_t r0, r1, r2, r3; - - asm volatile( - "{\n" - " .reg .pred pr;\n" - " setp.ne.u32 pr, %4, 0;\n" - " mov.u32 %0, 0;\n" - " mov.u32 %1, 0;\n" - " mov.u32 %2, 0;\n" - " mov.u32 %3, 0;\n" - " @pr ld.global.cg.v4.u32 {%0,%1,%2,%3}, [%5];\n" - "}\n" - : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3) - : "r"((int)pred), "l"(ptr)); - - *reinterpret_cast(&out) = uint4{r0, r1, r2, r3}; -} - -template -__device__ __forceinline__ void ld256_or_zero_cg_u32(PackedVec& out, - const void* ptr, - bool pred) { - uint32_t r0, r1, r2, r3, r4, r5, r6, r7; - - asm volatile( - "{\n" - " .reg .pred pr;\n" - " setp.ne.u32 pr, %8, 0;\n" - " mov.u32 %0, 0;\n" - " mov.u32 %1, 0;\n" - " mov.u32 %2, 0;\n" - " mov.u32 %3, 0;\n" - " mov.u32 %4, 0;\n" - " mov.u32 %5, 0;\n" - " mov.u32 %6, 0;\n" - " mov.u32 %7, 0;\n" - " @pr ld.global.cg.v8.u32 {%0,%1,%2,%3,%4,%5,%6,%7}, [%9];\n" - "}\n" - : "=r"(r0), "=r"(r1), "=r"(r2), "=r"(r3), "=r"(r4), "=r"(r5), "=r"(r6), - "=r"(r7) - : "r"((int)pred), "l"(ptr)); - - reinterpret_cast(&out)[0] = uint4{r0, r1, r2, r3}; - reinterpret_cast(&out)[1] = uint4{r4, r5, r6, r7}; -} - // Compute SF output offset for swizzled tensor core layout. // SF layout: [numMTiles, numKTiles, 32, 4, 4] // Caller must precompute: numKTiles = (numCols + 63) / 64 @@ -315,8 +205,8 @@ __device__ __forceinline__ uint8_t* sf_out_rowmajor_u8(int row, int pack, // Quantizes the provided PackedVec into the uint32_t output template -__device__ __forceinline__ fp4_packed_t -cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) { +__device__ __forceinline__ fp4_packed_t cvt_warp_fp16_to_fp4( + PackedVec& vec, float SFScaleVal, uint8_t* SFout) { // Get absolute maximum values among the local 8 values. auto localMax = __habs2(vec.elts[0]); @@ -372,11 +262,7 @@ cvt_warp_fp16_to_fp4(PackedVec& vec, float SFScaleVal, uint8_t* SFout) { #pragma unroll for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) { - if constexpr (std::is_same_v) { - fp2Vals[i] = __half22float2(vec.elts[i]); - } else { - fp2Vals[i] = __bfloat1622float2(vec.elts[i]); - } + fp2Vals[i] = cast_to_float2(vec.elts[i]); fp2Vals[i].x *= outputScale; fp2Vals[i].y *= outputScale; } @@ -395,22 +281,19 @@ __device__ __forceinline__ float2 silu2(float2 x) { } template -__inline__ __device__ PackedVec compute_silu_mul( - const PackedVec& x_vec, const PackedVec& y_vec) { - PackedVec result; +__inline__ __device__ PackedVec compute_silu_mul( + const PackedVec& x_vec, + const PackedVec& y_vec) { + PackedVec result; #pragma unroll for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; ++i) { // silu_mul in float32 - if constexpr (std::is_same_v) { - float2 silu_vec = silu2(__half22float2(x_vec.elts[i])); - result.elts[i] = __float22half2_rn( - __fmul2_rn(silu_vec, __half22float2(y_vec.elts[i]))); - } else { - float2 silu_vec = silu2(__bfloat1622float2(x_vec.elts[i])); - result.elts[i] = __float22bfloat162_rn( - __fmul2_rn(silu_vec, __bfloat1622float2(y_vec.elts[i]))); - } + using packed_t = typename PackedTypeConverter::Type; + float2 silu_vec = silu2(cast_to_float2(x_vec.elts[i])); + float2 y_f2 = cast_to_float2(y_vec.elts[i]); + result.elts[i] = cast_to_packed( + make_float2(silu_vec.x * y_f2.x, silu_vec.y * y_f2.y)); } return result; } diff --git a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu index eae500cb6325..41cf170a2431 100644 --- a/csrc/quantization/w8a8/cutlass/moe/moe_data.cu +++ b/csrc/quantization/w8a8/cutlass/moe/moe_data.cu @@ -263,12 +263,10 @@ void get_cutlass_moe_mm_data_caller( } template -__global__ void compute_pplx_data(int32_t* expert_offsets, - int32_t* problem_sizes1, - int32_t* problem_sizes2, - const int32_t* __restrict__ expert_num_tokens, - const int padded_m, const int n, - const int k) { +__global__ void compute_batched_moe_data( + int32_t* expert_offsets, int32_t* problem_sizes1, int32_t* problem_sizes2, + const int32_t* __restrict__ expert_num_tokens, const int padded_m, + const int n, const int k) { int expert_idx = threadIdx.x; expert_offsets[expert_idx] = expert_idx * padded_m; @@ -289,24 +287,22 @@ __global__ void compute_pplx_data(int32_t* expert_offsets, } } -void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, - torch::Tensor& problem_sizes2, - const torch::Tensor& expert_num_tokens, - const int64_t num_local_experts, - const int64_t padded_m, - const int64_t n, const int64_t k) { +void get_cutlass_batched_moe_mm_data_caller( + torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens, + const int64_t num_local_experts, const int64_t padded_m, const int64_t n, + const int64_t k) { auto stream = at::cuda::getCurrentCUDAStream(expert_offsets.device().index()); if (num_local_experts * padded_m > SWAP_AB_THRESHOLD) { - compute_pplx_data<<<1, num_local_experts, 0, stream>>>( + compute_batched_moe_data<<<1, num_local_experts, 0, stream>>>( static_cast(expert_offsets.data_ptr()), static_cast(problem_sizes1.data_ptr()), static_cast(problem_sizes2.data_ptr()), static_cast(expert_num_tokens.data_ptr()), padded_m, n, k); } else { - compute_pplx_data<<<1, num_local_experts, 0, stream>>>( + compute_batched_moe_data<<<1, num_local_experts, 0, stream>>>( static_cast(expert_offsets.data_ptr()), static_cast(problem_sizes1.data_ptr()), static_cast(problem_sizes2.data_ptr()), diff --git a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu index 82ccc19608cb..d6e82f1db9fa 100644 --- a/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu +++ b/csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu @@ -82,13 +82,11 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets_caller( torch::Tensor& problem_sizes1, torch::Tensor& problem_sizes2, const int64_t n, const int64_t k, const bool swap_ab); -void get_cutlass_pplx_moe_mm_data_caller(torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, - torch::Tensor& problem_sizes2, - const torch::Tensor& expert_num_tokens, - const int64_t num_local_experts, - const int64_t padded_m, - const int64_t n, const int64_t k); +void get_cutlass_batched_moe_mm_data_caller( + torch::Tensor& expert_offsets, torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, const torch::Tensor& expert_num_tokens, + const int64_t num_local_experts, const int64_t padded_m, const int64_t n, + const int64_t k); #endif void cutlass_scaled_mm_azp_sm75(torch::Tensor& c, torch::Tensor const& a, @@ -319,29 +317,30 @@ void get_cutlass_moe_mm_problem_sizes_from_expert_offsets( version_num, ". Required capability: 90, 100, or 120"); } -void get_cutlass_pplx_moe_mm_data(torch::Tensor& expert_offsets, - torch::Tensor& problem_sizes1, - torch::Tensor& problem_sizes2, - const torch::Tensor& expert_num_tokens, - const int64_t num_local_experts, - const int64_t padded_m, const int64_t n, - const int64_t k) { +void get_cutlass_batched_moe_mm_data(torch::Tensor& expert_offsets, + torch::Tensor& problem_sizes1, + torch::Tensor& problem_sizes2, + const torch::Tensor& expert_num_tokens, + const int64_t num_local_experts, + const int64_t padded_m, const int64_t n, + const int64_t k) { // This function currently gets compiled only if we have a valid cutlass moe // mm to run it for. int32_t version_num = get_sm_version_num(); #if (defined ENABLE_CUTLASS_MOE_SM90 && ENABLE_CUTLASS_MOE_SM90) || \ (defined ENABLE_CUTLASS_MOE_SM100 && ENABLE_CUTLASS_MOE_SM100) || \ (defined ENABLE_CUTLASS_MOE_SM120 && ENABLE_CUTLASS_MOE_SM120) - get_cutlass_pplx_moe_mm_data_caller(expert_offsets, problem_sizes1, - problem_sizes2, expert_num_tokens, - num_local_experts, padded_m, n, k); + get_cutlass_batched_moe_mm_data_caller(expert_offsets, problem_sizes1, + problem_sizes2, expert_num_tokens, + num_local_experts, padded_m, n, k); return; #endif - TORCH_CHECK_NOT_IMPLEMENTED( - false, - "No compiled get_cutlass_pplx_moe_mm_data: no cutlass_scaled_mm kernel " - "for CUDA device capability: ", - version_num, ". Required capability: 90, 100, or 120"); + TORCH_CHECK_NOT_IMPLEMENTED(false, + "No compiled get_cutlass_batched_moe_mm_data: no " + "cutlass_scaled_mm kernel " + "for CUDA device capability: ", + version_num, + ". Required capability: 90, 100, or 120"); } void cutlass_scaled_mm_azp(torch::Tensor& c, torch::Tensor const& a, diff --git a/csrc/rocm/skinny_gemms.cu b/csrc/rocm/skinny_gemms.cu index 15ebcc776ad7..19bb324bdcd5 100644 --- a/csrc/rocm/skinny_gemms.cu +++ b/csrc/rocm/skinny_gemms.cu @@ -304,8 +304,9 @@ __device__ inline unsigned int min__(uint32_t a, uint32_t b) { template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_sml_(const int K, const int M, const int Bx, const int By, - const scalar_t* B, const scalar_t* __restrict__ A, + wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap, const int M, + const int Bx, const int By, const scalar_t* B, + const scalar_t* __restrict__ A, const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { constexpr int max_lds_len = LDS_SIZE / 2; @@ -314,7 +315,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else constexpr bool use_mfma = false; #endif - using scalar8 = __attribute__((__vector_size__((A_CHUNK / 2) * sizeof(float)))) float; using half4 = @@ -346,13 +346,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // - Then the WG will move to another 8 K elements // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- - for (uint32_t k = 0; k < min__(K * N, max_lds_len); - k += THRDS * WvPrGrp * A_CHUNK) { - uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - if (k_in >= min__(K * N, max_lds_len)) break; - - *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; + k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { + #if defined(__gfx950__) + __builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0); + #else + *((bigType*)(&s[k])) = *((bigType*)(&A[k])); + #endif } __syncthreads(); @@ -360,9 +360,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t m = (blockIdx.x * _WvPrGrp + (threadIdx.y % _WvPrGrp)) * YTILE; - float sum[N][YTILE]; - scalar8 sum4[N][YTILE]; - //---------------------------------------------------- // Each wave works on a single column of weight matrix. // There are 16 waves per WG, and hence, each WG is @@ -386,44 +383,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // YTILE represents how many column of weight matrix // are being worked on by each wave. //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int n = 0; n < N; n++) - if constexpr (!use_mfma) - sum[n][i] = 0; - else - sum4[n][i] = {0, 0, 0, 0}; - - bigType bigA[N][UNRL]; - bigType bigB[YTILE][UNRL]; - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - // for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + float sum[N][YTILE] = {}; + scalar8 sum4[N][YTILE] = {}; + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + bigType bigA[N][UNRL] = {}; + bigType bigB[YTILE][UNRL]; // Fetch the weight matrix from memory! #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - const scalar_t* B_ = &B[(m + 0) * K + k_]; + const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)]; for (int y = 0; y < YTILE; y++) - bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[y * K]))); + bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -432,33 +405,20 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - - // Fetch A activation matrix in interleaved fashion from LDS or memory - for (int n = 0; n < N; n++) { - bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n]))); } } // Do the matrix multiplication in interleaved manner - #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! - #pragma unroll for (uint32_t n = 0; n < N; n++) { - #pragma unroll for (int y = 0; y < YTILE; y++) { if constexpr (!use_mfma) - #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) } else - #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 4; b++) sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); @@ -466,46 +426,44 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) } } } - + __builtin_amdgcn_sched_barrier(0); //---------------------------------------------------- // Final reduction step using shuffle //---------------------------------------------------- if constexpr (!use_mfma) { for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x118, 0xf, 0xf, + 1); // row_shr8 + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x114, 0xf, 0xf, + 1); // row_shr4 + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x112, 0xf, 0xf, + 1); // row_shr2 + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf, + 1); // row_shr1 + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf, + 1); // ROW_BCAST15 + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf, + 1); // ROW_BCAST31 } } if (threadIdx.x == 63) { + scalar_t biases[N][YTILE] = {}; + if (BIAS) + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx]; + } + } for (int n = 0; n < N; n++) { - for (int i = 0; i < YTILE; i++) { + for (int y = 0; y < YTILE; y++) { if constexpr (std::is_same_v) { - if (BIAS) - sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]); + sum[n][y] += __half2float(biases[n][y]); } else if constexpr (std::is_same_v) { - if (BIAS) - sum[n][i] += - __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + sum[n][y] += __bfloat162float(biases[n][y]); } - C[m + i + n * M] = __float2s(sum[n][i]); + C[m + y + n * M] = __float2s(sum[n][y]); } } } @@ -514,45 +472,43 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (int n = 0; n < N; n++) { #pragma unroll for (int y = 0; y < YTILE; y++) { - // float accm1 = 0; - // for (int i=0; i<64; i++) - // accm1 += __shfl(sum4[n][y][i%4], i); + /*float accm1 = 0; + for (int i=0; i<64; i++) + accm1 += __shfl(sum4[n][y][i%4], i); + sum4[n][y][0] = accm1;*/ float accm = sum4[n][y][0]; - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); + accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf, + 1); // row_shl1 + accm += __builtin_amdgcn_mov_dpp(sum4[n][y][2], 0x102, 0xf, 0xf, + 1); // row_shl2 + accm += __builtin_amdgcn_mov_dpp(sum4[n][y][3], 0x103, 0xf, 0xf, + 1); // row_shl3 + accm += __builtin_amdgcn_mov_dpp(accm, 0x104, 0xf, 0xf, + 1); // row_shl4 + accm += __builtin_amdgcn_mov_dpp(accm, 0x108, 0xf, 0xf, + 1); // row_shl8 + accm = __builtin_amdgcn_mov_dpp(accm, 0x11f, 0xf, 0xf, + 1); // row_shr15 + accm += __builtin_amdgcn_mov_dpp(accm, 0x142, 0xf, 0xf, + 1); // ROW_BCAST15 + accm += __builtin_amdgcn_mov_dpp(accm, 0x143, 0xf, 0xf, + 1); // ROW_BCAST31 sum4[n][y][0] = accm; } } if (threadIdx.x == 63) { + scalar_t biases[N][YTILE] = {}; + if (BIAS) + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx]; + } + } for (int n = 0; n < N; n++) { - for (int i = 0; i < YTILE; i++) { - if (BIAS) - sum4[n][i][0] += - __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); - C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + for (int y = 0; y < YTILE; y++) { + sum4[n][y][0] += __bfloat162float(biases[n][y]); + C[m + y + n * M] = __float2bfloat16(sum4[n][y][0]); } } } @@ -563,8 +519,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else // !defined(__HIP__GFX9__) TODO: Add NAVI support template -__global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx, - const int By, const scalar_t* B, +__global__ void wvSplitK_hf_sml_(const int K, const int Kbp, const int Kap, + const int M, const int Bx, const int By, + const scalar_t* B, const scalar_t* __restrict__ A, const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { @@ -577,8 +534,9 @@ __global__ void wvSplitK_hf_sml_(const int K, const int M, const int Bx, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_(const int K, const int M, const int Bx, const int By, - const scalar_t* B, const scalar_t* __restrict__ A, + wvSplitK_hf_(const int K, const int Kbp, const int Kap, const int M, + const int Bx, const int By, const scalar_t* B, + const scalar_t* __restrict__ A, const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { constexpr int max_lds_len = LDS_SIZE / 2; @@ -601,13 +559,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) scalar8 h8; }; - //---------------------------------------------------- - // Reserving 64 KB of LDS to have 1 WG / CU - // Goal is to bring the activation matrix A to the LDS - // and use it across the lifetime of the work group - // TODO: When activation matrix is larger than 64 KB - // then this is not going to work! - //---------------------------------------------------- __shared__ scalar_t s[max_lds_len]; //---------------------------------------------------- @@ -618,12 +569,6 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) commitColumn[i] = 1; } - //---------------------------------------------------- - // Indexing function into the column of weight matrix B - // Algorithm does 64 lane k-splitting / wave and uses - // WG ID and Thread ID to find the index. - //---------------------------------------------------- - // int _WvPrGrp = mindiv(N, CuCount * YTILE, WvPrGrp); uint32_t m = (blockIdx.x * _WvPrGrp + threadIdx.y) * YTILE; // Check whether there will be fragmentation! @@ -636,91 +581,34 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) m = startColumn; } - //---------------------------------------------------- - // Fetch the activation matrix to LDS - // Loop iteration: - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements - // - Each WG will fetch 512 * 16 => 8K elements - // - Then the WG will move to another 8 K elements - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- - for (uint32_t k = 0; k < min__(K * N, max_lds_len); - k += THRDS * WvPrGrp * A_CHUNK) { - uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - if (k_in >= min__(K * N, max_lds_len)) break; - - *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; + k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { + #if defined(__gfx950__) + __builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0); + #else + *((bigType*)(&s[k])) = *((bigType*)(&A[k])); + #endif } __syncthreads(); if (threadIdx.y >= _WvPrGrp) return; - float sum[N][YTILE]; - scalar8 sum4[N][YTILE]; - - //---------------------------------------------------- - // Each wave works on a single column of weight matrix. - // There are 16 waves per WG, and hence, each WG is - // working on 16 columns of weight matrix. Moreover, - // we tile in column direction by YTILE, so when YTILE=1 - // the above math is right, however, when YTILE=2 then - // each wave will be working on 2 columns and WG will - // be working on 32 columns. - // - // Top level loop that makes WGs persistent! - // - WGs iterates across columns of weight matrix - // - Each wave within WG works on a given column(s) - // - After completing first set of columns, WGs start - // working on the next set of available columns - //---------------------------------------------------- while (m < M) { - //---------------------------------------------------- - // 'sum' accumulates the matrix A x B computation - // split across 64 lanes. - // - // YTILE represents how many column of weight matrix - // are being worked on by each wave. - //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int n = 0; n < N; n++) - if constexpr (!use_mfma) - sum[n][i] = 0; - else - sum4[n][i] = {0, 0, 0, 0}; - - bigType bigA[N][UNRL]; - bigType bigB[YTILE][UNRL]; - //---------------------------------------------------- - // Fetch weight matrix B in interleaved K-split! - // - Each thread (lane) is fetching 8 elements (A_Chunk) - // - Each wave will fetch 64*8=> 512 elements (1024B) - // - YTILE represents the number of column being serviced - // by wave - // - Loop for fetching weight matrix (B) are unrolled - // - // Fetch activation matrix A from LDS - // - Loop for fetching activation matrix (A) are unrolled - // - // Finally, do the matrix multiplication in an unrolled - // fashion. This provides lot of food for compiler - // scheduling. - // - // TODO: Logic below will only work when K is multiple of 8 - //---------------------------------------------------- + float sum[N][YTILE] = {}; + scalar8 sum4[N][YTILE] = {}; + for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + bigType bigA[N][UNRL] = {}; + bigType bigB[YTILE][UNRL]; // Fetch the weight matrix from memory! #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - const scalar_t* B_ = &B[(m + 0) * K + k_]; - for (int b = 0; b < YTILE; b++) - bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K]))); + const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)]; + for (int y = 0; y < YTILE; y++) + bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -729,36 +617,23 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - - // Fetch A activation matrix in interleaved fashion from LDS or memory - for (int n = 0; n < N; n++) { - if (k_ + K * n < max_lds_len) - bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + if (k_ + Kap * n < max_lds_len) + bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n]))); else - bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); + bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n]))); } } // Do the matrix multiplication in interleaved manner - #pragma unroll for (uint32_t n = 0; n < N; n++) { - #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! - #pragma unroll for (int y = 0; y < YTILE; y++) { if constexpr (!use_mfma) - #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) } else - #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 4; b++) sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); @@ -773,40 +648,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if constexpr (!use_mfma) { for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x118, 0xf, 0xf, + 1); // row_shr8 + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x114, 0xf, 0xf, + 1); // row_shr4 + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x112, 0xf, 0xf, + 1); // row_shr2 + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf, + 1); // row_shr1 + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf, + 1); // ROW_BCAST15 + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf, + 1); // ROW_BCAST31 } } if (threadIdx.x == 63) { + scalar_t biases[N][YTILE] = {}; + if (BIAS) + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx]; + } + } for (int n = 0; n < N; n++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) { + for (int y = 0; y < YTILE; y++) { + if (commitColumn[y]) { if constexpr (std::is_same_v) { - if (BIAS) - sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]); + sum[n][y] += __half2float(biases[n][y]); } else if constexpr (std::is_same_v) { - if (BIAS) - sum[n][i] += - __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + sum[n][y] += __bfloat162float(biases[n][y]); } - C[m + i + n * M] = __float2s(sum[n][i]); + C[m + y + n * M] = __float2s(sum[n][y]); } } } @@ -819,44 +692,39 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // float accm1 = 0; // for (int i=0; i<64; i++) // accm1 += __shfl(sum4[n][y][i%4], i); - float accm = sum4[n][y][0]; - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - + accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf, + 1); // row_shl1 + accm += __builtin_amdgcn_mov_dpp(sum4[n][y][2], 0x102, 0xf, 0xf, + 1); // row_shl2 + accm += __builtin_amdgcn_mov_dpp(sum4[n][y][3], 0x103, 0xf, 0xf, + 1); // row_shl3 + accm += __builtin_amdgcn_mov_dpp(accm, 0x104, 0xf, 0xf, + 1); // row_shl4 + accm += __builtin_amdgcn_mov_dpp(accm, 0x108, 0xf, 0xf, + 1); // row_shl8 + accm = __builtin_amdgcn_mov_dpp(accm, 0x11f, 0xf, 0xf, + 1); // row_shr15 + accm += __builtin_amdgcn_mov_dpp(accm, 0x142, 0xf, 0xf, + 1); // ROW_BCAST15 + accm += __builtin_amdgcn_mov_dpp(accm, 0x143, 0xf, 0xf, + 1); // ROW_BCAST31 sum4[n][y][0] = accm; } } if (threadIdx.x == 63) { + scalar_t biases[N][YTILE] = {}; + if (BIAS) + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx]; + } + } for (int n = 0; n < N; n++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) { - if (BIAS) - sum4[n][i][0] += - __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); - C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + for (int y = 0; y < YTILE; y++) { + if (commitColumn[y]) { + sum4[n][y][0] += __bfloat162float(biases[n][y]); + C[m + y + n * M] = __float2bfloat16(sum4[n][y][0]); } } } @@ -880,9 +748,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else // !defined(__HIP__GFX9__) TODO: Add NAVI support template -__global__ void wvSplitK_hf_(const int K, const int M, const int Bx, - const int By, const scalar_t* B, - const scalar_t* __restrict__ A, +__global__ void wvSplitK_hf_(const int K, const int Kbp, const int Kap, + const int M, const int Bx, const int By, + const scalar_t* B, const scalar_t* __restrict__ A, const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { UNREACHABLE_CODE @@ -894,8 +762,9 @@ __global__ void wvSplitK_hf_(const int K, const int M, const int Bx, template __global__ void __launch_bounds__(WvPrGrp* THRDS) - wvSplitK_hf_big_(const int K, const int M, const int Bx, const int By, - const scalar_t* B, const scalar_t* __restrict__ A, + wvSplitK_hf_big_(const int K, const int Kbp, const int Kap, const int M, + const int Bx, const int By, const scalar_t* B, + const scalar_t* __restrict__ A, const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { constexpr int max_lds_len = LDS_SIZE / 2; @@ -966,13 +835,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) //---------------------------------------------------- #define PCML #ifndef PCML - for (uint32_t k = 0; k < min__(K * N, max_lds_len); - k += THRDS * WvPrGrp * A_CHUNK) { - uint32_t k_in = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - - if (k_in >= min__(K * N, max_lds_len)) break; - - *((bigType*)(&s[k_in])) = *((bigType*)(&A[k_in])); + for (uint32_t k = (threadIdx.y * THRDS + threadIdx.x) * A_CHUNK; + k < min__(Kap * N, max_lds_len); k += THRDS * WvPrGrp * A_CHUNK) { + #if defined(__gfx950__) + __builtin_amdgcn_global_load_lds((int*)(&A[k]), (int*)(&s[k]), 16, 0, 0); + #else + *((bigType*)(&s[k])) = *((bigType*)(&A[k])); + #endif } __syncthreads(); #endif @@ -987,10 +856,7 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) ? kFit : (kFit - kFit % TUC); // round up to multiple of TUC // if (kFit == 0) kFit = TUC; - kFit = min__(kFit, K); - - float sum[N][YTILE]; - scalar8 sum4[N][YTILE]; + kFit = min__(kFit, Kap); //---------------------------------------------------- // Each wave works on a single column of weight matrix. @@ -1021,15 +887,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // YTILE represents how many column of weight matrix // are being worked on by each wave. //---------------------------------------------------- - for (int i = 0; i < YTILE; i++) - for (int n = 0; n < N; n++) - if constexpr (!use_mfma) - sum[n][i] = 0; - else - sum4[n][i] = {0, 0, 0, 0}; - - bigType bigA[N][UNRL]; - bigType bigB[YTILE][UNRL]; + float sum[N][YTILE] = {}; + scalar8 sum4[N][YTILE] = {}; + //---------------------------------------------------- // Fetch weight matrix B in interleaved K-split! // - Each thread (lane) is fetching 8 elements (A_Chunk) @@ -1048,18 +908,26 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // TODO: Logic below will only work when K is multiple of 8 //---------------------------------------------------- for (uint32_t k1 = 0; k1 < K; k1 += THRDS * A_CHUNK * UNRL) { + bigType bigA[N][UNRL] = {}; + bigType bigB[YTILE][UNRL]; + #ifdef PCML if ((k1 == 0) || (k1 == kBase + kFit)) { // load next chunk of A[] to LDS if (k1 != 0) kBase += kFit; __syncthreads(); for (uint32_t k = 0; k < kFit; k += THRDS * _WvPrGrp * A_CHUNK) { uint32_t kOff = k + ((threadIdx.y * THRDS + threadIdx.x) * A_CHUNK); - if (kBase + kOff >= K) break; + if (kBase + kOff >= Kap) break; if (kOff >= kFit) break; for (uint32_t n = 0; n < N; n++) { - uint32_t k_in = kBase + n * K + kOff; + uint32_t k_in = kBase + n * Kap + kOff; uint32_t k_ot = n * kFit + kOff; + #if defined(__gfx950__) + __builtin_amdgcn_global_load_lds((int*)(&A[k_in]), (int*)(&s[k_ot]), + 16, 0, 0); + #else *((bigType*)(&s[k_ot])) = *((bigType*)(&A[k_in])); + #endif } } __syncthreads(); @@ -1072,11 +940,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) for (uint32_t k2 = 0; k2 < UNRL; k2++) { uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - - const scalar_t* B_ = &B[(m + 0) * K + k_]; - for (int b = 0; b < YTILE; b++) - bigB[b][k2].h8 = (loadnt((scalar8*)(&B_[b * K]))); + const scalar_t* B_ = &B[min__(k_, K - A_CHUNK)]; + for (int y = 0; y < YTILE; y++) + bigB[y][k2].h8 = (loadnt((scalar8*)(&B_[min__(y + m, M - 1) * Kbp]))); } // Fetch activation matrix from either just LDS or from both LDS / memory @@ -1085,17 +951,14 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) uint32_t k = k1 + k2 * THRDS * A_CHUNK; uint32_t k_ = k + threadIdx.x * A_CHUNK; if (k_ >= K) break; - - // Fetch A activation matrix in interleaved fashion from LDS or memory - for (int n = 0; n < N; n++) { #ifdef PCML bigA[n][k2] = *((const bigType*)(&(s[k_ - kBase + kFit * n]))); #else - if (k_ + K * n < 32 * 1024) - bigA[n][k2] = *((const bigType*)(&(s[k_ + K * n]))); + if (k_ + Kap * n < max_lds_len) + bigA[n][k2] = *((const bigType*)(&(s[k_ + Kap * n]))); else - bigA[n][k2] = *((const bigType*)(&(A[k_ + K * n]))); + bigA[n][k2] = *((const bigType*)(&(A[k_ + Kap * n]))); #endif } } @@ -1103,22 +966,13 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) // Do the matrix multiplication in interleaved manner #pragma unroll for (uint32_t k2 = 0; k2 < UNRL; k2++) { - uint32_t k = k1 + k2 * THRDS * A_CHUNK; - uint32_t k_ = k + threadIdx.x * A_CHUNK; - if (k_ >= K) break; - #pragma unroll for (uint32_t n = 0; n < N; n++) { - // Do the matrix multiplication of activation and weight matrix - // - Remember the accumulation is happening for K-split of 64! - #pragma unroll for (int y = 0; y < YTILE; y++) { if constexpr (!use_mfma) - #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 2; b++) { DOT2C(sum[n][y], bigA[n][k2].f[b], bigB[y][k2].f[b]) } else - #pragma unroll for (uint32_t b = 0; b < A_CHUNK / 4; b++) sum4[n][y] = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k( bigA[n][k2].h4[b], bigB[y][k2].h4[b], sum4[n][y], 0, 0, 0); @@ -1141,40 +995,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) if constexpr (!use_mfma) { for (int n = 0; n < N; n++) { for (int y = 0; y < YTILE; y++) { - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:8 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:4 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shr:2 bound_ctrl:0 " - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 wave_shr:1 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(sum[n][y]) - : "0"(sum[n][y]), "v"(sum[n][y]), "v"(sum[n][y])); + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x118, 0xf, 0xf, + 1); // row_shr8 + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x114, 0xf, 0xf, + 1); // row_shr4 + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x112, 0xf, 0xf, + 1); // row_shr2 + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x111, 0xf, 0xf, + 1); // row_shr1 + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x142, 0xf, 0xf, + 1); // ROW_BCAST15 + sum[n][y] += __builtin_amdgcn_mov_dpp(sum[n][y], 0x143, 0xf, 0xf, + 1); // ROW_BCAST31 } } if (threadIdx.x == 63) { + scalar_t biases[N][YTILE] = {}; + if (BIAS) + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx]; + } + } for (int n = 0; n < N; n++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) { + for (int y = 0; y < YTILE; y++) { + if (commitColumn[y]) { if constexpr (std::is_same_v) { - if (BIAS) - sum[n][i] += __half2float(BIAS[(m + i) % Bx + (n % By) * M]); + sum[n][y] += __half2float(biases[n][y]); } else if constexpr (std::is_same_v) { - if (BIAS) - sum[n][i] += - __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); + sum[n][y] += __bfloat162float(biases[n][y]); } - C[m + i + n * M] = __float2s(sum[n][i]); + C[m + y + n * M] = __float2s(sum[n][y]); } } } @@ -1185,42 +1037,38 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #pragma unroll for (int y = 0; y < YTILE; y++) { float accm = sum4[n][y][0]; - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:1 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[n][y][1]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:2 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[n][y][2]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:3 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(sum4[n][y][3]), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:4 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_shl:8 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - asm("s_nop 0\n\tv_mov_b32 %0, %2 row_shr:15 bound_ctrl:0 " - : "=v"(accm) - : "0"(accm), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:15 bound_ctrl:0" - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - asm("s_nop 0\n\tv_add_f32 %0, %2, %3 row_bcast:31 bound_ctrl:0" - : "=v"(accm) - : "0"(accm), "v"(accm), "v"(accm)); - + accm += __builtin_amdgcn_mov_dpp(sum4[n][y][1], 0x101, 0xf, 0xf, + 1); // row_shl1 + accm += __builtin_amdgcn_mov_dpp(sum4[n][y][2], 0x102, 0xf, 0xf, + 1); // row_shl2 + accm += __builtin_amdgcn_mov_dpp(sum4[n][y][3], 0x103, 0xf, 0xf, + 1); // row_shl3 + accm += __builtin_amdgcn_mov_dpp(accm, 0x104, 0xf, 0xf, + 1); // row_shl4 + accm += __builtin_amdgcn_mov_dpp(accm, 0x108, 0xf, 0xf, + 1); // row_shl8 + accm = __builtin_amdgcn_mov_dpp(accm, 0x11f, 0xf, 0xf, + 1); // row_shr15 + accm += __builtin_amdgcn_mov_dpp(accm, 0x142, 0xf, 0xf, + 1); // ROW_BCAST15 + accm += __builtin_amdgcn_mov_dpp(accm, 0x143, 0xf, 0xf, + 1); // ROW_BCAST31 sum4[n][y][0] = accm; } } if (threadIdx.x == 63) { + scalar_t biases[N][YTILE] = {}; + if (BIAS) + for (int n = 0; n < N; n++) { + for (int y = 0; y < YTILE; y++) { + biases[n][y] = BIAS[(m + y) % Bx + (n % By) * Bx]; + } + } for (int n = 0; n < N; n++) { - for (int i = 0; i < YTILE; i++) { - if (commitColumn[i]) { - if (BIAS) - sum4[n][i][0] += - __bfloat162float(BIAS[(m + i) % Bx + (n % By) * M]); - C[m + i + n * M] = __float2bfloat16(sum4[n][i][0]); + for (int y = 0; y < YTILE; y++) { + if (commitColumn[y]) { + sum4[n][y][0] += __bfloat162float(biases[n][y]); + C[m + y + n * M] = __float2bfloat16(sum4[n][y][0]); } } } @@ -1244,8 +1092,9 @@ __global__ void __launch_bounds__(WvPrGrp* THRDS) #else // !defined(__HIP__GFX9__) TODO: Add NAVI support template -__global__ void wvSplitK_hf_big_(const int K, const int M, const int Bx, - const int By, const scalar_t* B, +__global__ void wvSplitK_hf_big_(const int K, const int Kbp, const int Kap, + const int M, const int Bx, const int By, + const scalar_t* B, const scalar_t* __restrict__ A, const scalar_t* __restrict__ BIAS, scalar_t* C, const int _WvPrGrp, const int CuCount) { @@ -1272,6 +1121,8 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, auto M_in = in_a.size(0); auto K_in = in_a.size(1); auto N_in = in_b.size(0); + auto Kap_in = in_a.stride(0); + auto Kbp_in = in_b.stride(0); auto Bx_in = (in_bias.has_value() && in_bias->numel() > 0) ? (in_bias->sizes().size() == 2) ? in_bias->size(1) : in_bias->size(0) @@ -1296,27 +1147,30 @@ torch::Tensor wvSplitK(const at::Tensor& in_a, const at::Tensor& in_b, const cudaStream_t stream = at::cuda::getCurrentCUDAStream(); const int max_lds_len = get_lds_size() / 2; -#define WVSPLITK(_YTILE, _UNRL, _N) \ - { \ - dim3 block(64, 16); \ - int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \ - if ((K_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \ - wvSplitK_hf_sml_ \ - <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ - biasf4, c, __wvPrGrp, CuCount); \ - else if (K_in * N_in <= max_lds_len * 1.2) \ - wvSplitK_hf_ \ - <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ - biasf4, c, __wvPrGrp, CuCount); \ - else \ - wvSplitK_hf_big_ \ - <<>>(K_in, M_in, Bx_in, By_in, af4, bf4, \ - biasf4, c, __wvPrGrp, CuCount); \ +#define WVSPLITK(_YTILE, _UNRL, _N) \ + { \ + dim3 block(64, 16); \ + int __wvPrGrp = mindiv(M_in, CuCount * _YTILE, 16); \ + if ((Kbp_in * N_in <= max_lds_len) && (M_in % _YTILE == 0)) \ + wvSplitK_hf_sml_ \ + <<>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \ + By_in, af4, bf4, biasf4, c, __wvPrGrp, \ + CuCount); \ + else if (Kbp_in * N_in <= max_lds_len * 1.2) \ + wvSplitK_hf_ \ + <<>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \ + By_in, af4, bf4, biasf4, c, __wvPrGrp, \ + CuCount); \ + else \ + wvSplitK_hf_big_ \ + <<>>(K_in, Kap_in, Kbp_in, M_in, Bx_in, \ + By_in, af4, bf4, biasf4, c, __wvPrGrp, \ + CuCount); \ } #define WVSPLIT_TILE(_sYT, __N) \ { \ - bool fit_lds = (K_in * N_in <= max_lds_len); \ + bool fit_lds = (Kbp_in * N_in <= max_lds_len); \ if (_sYT <= 1) \ WVSPLITK(1, 4, __N) \ else if ((__N == 1) || (!fit_lds) || (_sYT <= 4 * 2)) \ diff --git a/csrc/torch_bindings.cpp b/csrc/torch_bindings.cpp index 39b6bc98a843..f7ea8c788dd0 100644 --- a/csrc/torch_bindings.cpp +++ b/csrc/torch_bindings.cpp @@ -426,6 +426,22 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { " Tensor problem_sizes, Tensor expert_offsets, Tensor sf_offsets) -> ()"); // conditionally compiled so impl registration is in source file + // Expert-specialization mxfp8 blockscaled grouped quantization (SM100+). + ops.def( + "mxfp8_experts_quant(" + " Tensor input, Tensor problem_sizes, Tensor expert_offsets," + " Tensor blockscale_offsets, Tensor! quant_output, Tensor! scale_factor)" + " -> ()"); + // conditionally compiled so impl registration is in source file + + // Expert-specialization mxfp8 blockscaled grouped GEMM (SM100+). + ops.def( + "cutlass_mxfp8_grouped_mm(" + " Tensor a, Tensor b, Tensor sfa, Tensor sfb, Tensor! out," + " Tensor problem_sizes, Tensor expert_offsets, Tensor blockscale_offsets)" + " -> ()"); + // conditionally compiled so impl registration is in source file + // CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column // quantization, as well as bias ops.def( @@ -489,19 +505,19 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { &get_cutlass_moe_mm_problem_sizes_from_expert_offsets); // A function that computes data required to run fused MoE with w8a8 grouped - // GEMM and PPLX. It takes expert_num_tokens and non_zero_expert_idxs + // GEMM in batched expert format. It takes expert_num_tokens // as an input, and computes expert_offsets (token start indices of each // expert). In addition to this, it computes problem sizes for each expert's // multiplication used by the two mms called from fused MoE operation. ops.def( - "get_cutlass_pplx_moe_mm_data(Tensor! expert_offsets, " + "get_cutlass_batched_moe_mm_data(Tensor! expert_offsets, " " Tensor! problem_sizes1, " " Tensor! problem_sizes2, " " Tensor expert_num_tokens, " " int num_local_experts, int padded_m, " " int n, int k) -> ()"); - ops.impl("get_cutlass_pplx_moe_mm_data", torch::kCUDA, - &get_cutlass_pplx_moe_mm_data); + ops.impl("get_cutlass_batched_moe_mm_data", torch::kCUDA, + &get_cutlass_batched_moe_mm_data); // Check if cutlass scaled_mm supports block quantization (used by DeepSeekV3) ops.def( @@ -640,7 +656,9 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) { "int block_size," "Tensor? block_idx_first_scheduled_token," "Tensor? block_idx_last_scheduled_token," - "Tensor? initial_state_idx) -> ()"); + "Tensor? initial_state_idx," + "Tensor? cu_chunk_seqlen," + "Tensor? last_chunk_indices) -> ()"); ops.impl("selective_scan_fwd", torch::kCUDA, &selective_scan_fwd); // Hadamard transforms diff --git a/docker/Dockerfile b/docker/Dockerfile index cc2ccc11cdcb..495a480b7582 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -132,8 +132,10 @@ ENV UV_LINK_MODE=copy # Verify GCC version RUN gcc --version -# Ensure CUDA compatibility library is loaded -RUN echo "/usr/local/cuda-$(echo "$CUDA_VERSION" | cut -d. -f1,2)/compat/" > /etc/ld.so.conf.d/cuda-compat.conf && ldconfig +# Enable CUDA forward compatibility by setting '-e VLLM_ENABLE_CUDA_COMPATIBILITY=1' +# Only needed for datacenter/professional GPUs with older drivers. +# See: https://docs.nvidia.com/deploy/cuda-compatibility/ +ENV VLLM_ENABLE_CUDA_COMPATIBILITY=0 # ============================================================ # SLOW-CHANGING DEPENDENCIES BELOW @@ -306,7 +308,7 @@ RUN --mount=type=cache,target=/root/.cache/ccache \ #################### CSRC BUILD IMAGE #################### #################### EXTENSIONS BUILD IMAGE #################### -# Build DeepGEMM, pplx-kernels, DeepEP - runs in PARALLEL with csrc-build +# Build DeepGEMM, DeepEP - runs in PARALLEL with csrc-build # This stage is independent and doesn't affect csrc cache FROM base AS extensions-build ARG CUDA_VERSION @@ -333,10 +335,9 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # Ensure the wheel dir exists so COPY won't fail when DeepGEMM is skipped RUN mkdir -p /tmp/deepgemm/dist && touch /tmp/deepgemm/dist/.deepgemm_skipped -# Build pplx-kernels and DeepEP wheels +# Build DeepEP wheels COPY tools/ep_kernels/install_python_libraries.sh /tmp/install_python_libraries.sh # Defaults moved here from tools/ep_kernels/install_python_libraries.sh for centralized version management -ARG PPLX_COMMIT_HASH=12cecfd ARG DEEPEP_COMMIT_HASH=73b6ea4 ARG NVSHMEM_VER RUN --mount=type=cache,target=/root/.cache/uv \ @@ -345,7 +346,6 @@ RUN --mount=type=cache,target=/root/.cache/uv \ /tmp/install_python_libraries.sh \ --workspace /tmp/ep_kernels_workspace \ --mode wheel \ - ${PPLX_COMMIT_HASH:+--pplx-ref "$PPLX_COMMIT_HASH"} \ ${DEEPEP_COMMIT_HASH:+--deepep-ref "$DEEPEP_COMMIT_HASH"} \ ${NVSHMEM_VER:+--nvshmem-ver "$NVSHMEM_VER"} && \ find /tmp/ep_kernels_workspace/nvshmem -name '*.a' -delete @@ -560,8 +560,10 @@ ENV UV_HTTP_TIMEOUT=500 ENV UV_INDEX_STRATEGY="unsafe-best-match" ENV UV_LINK_MODE=copy -# Ensure CUDA compatibility library is loaded -RUN echo "/usr/local/cuda-$(echo "$CUDA_VERSION" | cut -d. -f1,2)/compat/" > /etc/ld.so.conf.d/cuda-compat.conf && ldconfig +# Enable CUDA forward compatibility by setting '-e VLLM_ENABLE_CUDA_COMPATIBILITY=1' +# Only needed for datacenter/professional GPUs with older drivers. +# See: https://docs.nvidia.com/deploy/cuda-compatibility/ +ENV VLLM_ENABLE_CUDA_COMPATIBILITY=0 # ============================================================ # SLOW-CHANGING DEPENDENCIES BELOW @@ -672,7 +674,7 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # Pytorch now installs NVSHMEM, setting LD_LIBRARY_PATH ENV LD_LIBRARY_PATH=/usr/local/cuda/lib64:$LD_LIBRARY_PATH -# Install EP kernels wheels (pplx-kernels and DeepEP) that have been built in the `build` stage +# Install EP kernels wheels (DeepEP) that have been built in the `build` stage RUN --mount=type=bind,from=build,src=/tmp/ep_kernels_workspace/dist,target=/vllm-workspace/ep_kernels/dist \ --mount=type=cache,target=/root/.cache/uv \ uv pip install --system ep_kernels/dist/*.whl --verbose \ diff --git a/docker/Dockerfile.xpu b/docker/Dockerfile.xpu index ba7dd848bdfd..3ed6de8fc722 100644 --- a/docker/Dockerfile.xpu +++ b/docker/Dockerfile.xpu @@ -6,8 +6,7 @@ ARG PYTHON_VERSION=3.12 ARG PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/xpu" RUN wget -O- https://apt.repos.intel.com/intel-gpg-keys/GPG-PUB-KEY-INTEL-SW-PRODUCTS.PUB | gpg --dearmor | tee /usr/share/keyrings/oneapi-archive-keyring.gpg > /dev/null && \ - echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list && \ - add-apt-repository -y ppa:kobuk-team/intel-graphics + echo "deb [signed-by=/usr/share/keyrings/oneapi-archive-keyring.gpg] https://apt.repos.intel.com/oneapi all main" | tee /etc/apt/sources.list.d/oneAPI.list RUN apt clean && apt-get update -y && \ apt-get install -y --no-install-recommends --fix-missing \ @@ -28,9 +27,22 @@ RUN apt clean && apt-get update -y && \ python3-pip RUN apt update && apt upgrade -y && \ - apt install -y libze1 libze-dev libze-intel-gpu1 intel-opencl-icd libze-intel-gpu-raytracing intel-ocloc && \ apt install -y intel-oneapi-compiler-dpcpp-cpp-2025.3 +# Install UMD +RUN mkdir neo && \ + cd neo && \ + wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.24.8/intel-igc-core-2_2.24.8+20344_amd64.deb && \ + wget https://github.com/intel/intel-graphics-compiler/releases/download/v2.24.8/intel-igc-opencl-2_2.24.8+20344_amd64.deb && \ + wget https://github.com/intel/compute-runtime/releases/download/25.48.36300.8/intel-ocloc_25.48.36300.8-0_amd64.deb && \ + wget https://github.com/intel/compute-runtime/releases/download/25.48.36300.8/intel-opencl-icd_25.48.36300.8-0_amd64.deb && \ + wget https://github.com/intel/compute-runtime/releases/download/25.48.36300.8/libigdgmm12_22.8.2_amd64.deb && \ + wget https://github.com/intel/compute-runtime/releases/download/25.48.36300.8/libze-intel-gpu1_25.48.36300.8-0_amd64.deb && \ + wget https://github.com/oneapi-src/level-zero/releases/download/v1.26.0/level-zero_1.26.0+u24.04_amd64.deb && \ + dpkg -i *.deb && \ + cd .. && \ + rm -rf neo + ENV PATH="/root/.local/bin:$PATH" ENV VIRTUAL_ENV="/opt/venv" ENV UV_PYTHON_INSTALL_DIR=/opt/uv/python @@ -103,9 +115,57 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # install development dependencies (for testing) RUN uv pip install -e tests/vllm_test_utils -# install nixl from source code -ENV NIXL_VERSION=0.7.0 -RUN python /workspace/vllm/tools/install_nixl_from_source_ubuntu.py +# install NIXL and UCX from source code +ARG UCX_VERSION=e5d98879705239d254ede40b4a52891850cb5349 +ARG NIXL_VERSION=0.7.0 + +RUN apt-get update && apt-get install -y \ + pciutils \ + net-tools \ + iproute2 \ + hwloc \ + numactl \ + wget \ + curl \ + git \ + build-essential \ + autoconf \ + automake \ + libtool \ + pkg-config \ + rdma-core \ + libibverbs-dev \ + ibverbs-utils \ + libibverbs1 \ + librdmacm-dev \ + librdmacm1 \ + libibumad-dev \ + libibumad3 \ + libibmad-dev \ + libibmad5 \ + infiniband-diags \ + perftest \ + ibutils \ + libmlx5-1 \ + libmlx4-1 \ + ibverbs-providers \ + librdmacm1t64 + +ENV PKG_CONFIG_PATH=/tmp/ucx_install/lib/pkgconfig:${PKG_CONFIG_PATH} +ENV LD_LIBRARY_PATH=/tmp/ucx_install/lib:${LD_LIBRARY_PATH} +RUN --mount=type=cache,target=/root/.cache/uv \ + git clone https://github.com/openucx/ucx /tmp/ucx_source && \ + cd /tmp/ucx_source && git checkout "${UCX_VERSION}" && \ + bash autogen.sh && \ + ./configure --prefix=/tmp/ucx_install --with-ze=yes --enable-examples --enable-mt && \ + make CFLAGS="-Wno-error=incompatible-pointer-types" -j8 && make install && \ + git clone https://github.com/ai-dynamo/nixl /tmp/nixl_source && \ + cd /tmp/nixl_source && git checkout "${NIXL_VERSION}" && \ + cd /tmp/nixl_source && \ + uv pip install --upgrade meson pybind11 patchelf && \ + uv pip install -r requirements.txt && \ + uv pip install . && \ + rm -rf /tmp/ucx_source /tmp/nixl_source # FIX triton RUN --mount=type=cache,target=/root/.cache/uv \ diff --git a/docker/versions.json b/docker/versions.json index 24f4b6e7d1b7..fa090c10c443 100644 --- a/docker/versions.json +++ b/docker/versions.json @@ -52,9 +52,6 @@ "DEEPGEMM_GIT_REF": { "default": "477618cd51baffca09c4b0b87e97c03fe827ef03" }, - "PPLX_COMMIT_HASH": { - "default": "12cecfd" - }, "DEEPEP_COMMIT_HASH": { "default": "73b6ea4" }, diff --git a/docs/benchmarking/cli.md b/docs/benchmarking/cli.md index 7bb91239c58e..8bbd9b0c0e3e 100644 --- a/docs/benchmarking/cli.md +++ b/docs/benchmarking/cli.md @@ -4,6 +4,11 @@ This section guides you through running benchmark tests with the extensive datas It's a living document, updated as new features and datasets become available. +!!! tip + The benchmarks described on this page are mainly for evaluating specific vLLM features as well as regression testing. + + For benchmarking production vLLM servers, we recommend [GuideLLM](https://github.com/vllm-project/guidellm), an established performance benchmarking framework with live progress updates and automatic report generation. It is also more flexible than `vllm bench serve` in terms of dataset loading, request formatting, and workload patterns. + ## Dataset Overview