From 4f2ed5fddb52b45ea8a8b0dad4a354802a6b3bea Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Sun, 29 Mar 2026 21:30:26 -0500 Subject: [PATCH 0001/1601] [ROCm][CI] Enable hybrid chunked prefill test (#38317) Signed-off-by: Andreas Karatzas --- .buildkite/test-amd.yaml | 13 +++++++++++++ tests/v1/e2e/test_hybrid_chunked_prefill.py | 10 ++++++++-- 2 files changed, 21 insertions(+), 2 deletions(-) diff --git a/.buildkite/test-amd.yaml b/.buildkite/test-amd.yaml index 49987880c1a9..b7254efd2dc2 100644 --- a/.buildkite/test-amd.yaml +++ b/.buildkite/test-amd.yaml @@ -1801,6 +1801,19 @@ steps: - tests/v1/e2e commands: - pytest -v -s v1/e2e/spec_decode/test_spec_decode.py -k "eagle_correctness_heavy" + + +- label: V1 e2e (4xH100-4xMI325) # TBD + timeout_in_minutes: 180 + mirror_hardwares: [amdexperimental, amdproduction, amdgfx942nightly, amdmi325] + agent_pool: mi325_4 + optional: true + source_file_dependencies: + - vllm/v1/attention/backends/utils.py + - vllm/v1/worker/gpu_model_runner.py + - tests/v1/e2e/test_hybrid_chunked_prefill.py + commands: + - pytest -v -s v1/e2e/test_hybrid_chunked_prefill.py - label: V1 Spec Decode # TBD diff --git a/tests/v1/e2e/test_hybrid_chunked_prefill.py b/tests/v1/e2e/test_hybrid_chunked_prefill.py index 1790343ca836..dd8a5f5cb6d4 100644 --- a/tests/v1/e2e/test_hybrid_chunked_prefill.py +++ b/tests/v1/e2e/test_hybrid_chunked_prefill.py @@ -36,14 +36,20 @@ ] -@pytest.mark.skipif(not current_platform.is_cuda(), reason="CUDA not available") @pytest.mark.parametrize( "model_name", [ pytest.param("Qwen/Qwen3.5-4B", marks=[large_gpu_mark(min_gb=40)]), pytest.param( "nvidia/NVIDIA-Nemotron-3-Super-120B-A12B-FP8", - marks=[large_gpu_mark(min_gb=80)] + multi_gpu_marks(num_gpus=4), + marks=[large_gpu_mark(min_gb=80)] + + multi_gpu_marks(num_gpus=4) + + [ + pytest.mark.skipif( + not current_platform.is_cuda(), + reason="modelopt quantization is supported only on CUDA", + ) + ], ), ], ) From a6db99ba02ec67e1982619dc5bbed4b55ba4a953 Mon Sep 17 00:00:00 2001 From: Stanislav Kirillov Date: Mon, 30 Mar 2026 06:07:28 +0200 Subject: [PATCH 0002/1601] [Bugfix] Support multi-type params parsing for DeepSeek v3.2 (#33703) Signed-off-by: Stanislav Kirillov Co-authored-by: Stanislav Kirillov Co-authored-by: Chauncey --- .../test_deepseekv32_tool_parser.py | 181 ++++++++++++++++++ vllm/tool_parsers/deepseekv32_tool_parser.py | 38 ++-- 2 files changed, 201 insertions(+), 18 deletions(-) diff --git a/tests/tool_parsers/test_deepseekv32_tool_parser.py b/tests/tool_parsers/test_deepseekv32_tool_parser.py index 14462da5b9cb..0adffa7ec824 100644 --- a/tests/tool_parsers/test_deepseekv32_tool_parser.py +++ b/tests/tool_parsers/test_deepseekv32_tool_parser.py @@ -11,6 +11,7 @@ import pytest +from vllm.tokenizers import get_tokenizer from vllm.tool_parsers.deepseekv32_tool_parser import DeepSeekV32ToolParser # --------------------------------------------------------------------------- @@ -474,3 +475,183 @@ def test_no_emission_while_incomplete(self, parser): deltas = self._stream(parser, partial_text) # Should have no tool call deltas yet assert all(not d.tool_calls for d in deltas) + + +@pytest.fixture(scope="module") +def deepseekv32_tokenizer(): + return get_tokenizer(tokenizer_name="deepseek-ai/DeepSeek-V3.2") + + +@pytest.fixture +def parser(deepseekv32_tokenizer): + return DeepSeekV32ToolParser(deepseekv32_tokenizer) + + +def test_convert_param_value_single_types(parser): + """Test _convert_param_value with single type parameters.""" + # Test string type + assert parser._convert_param_value("hello", "string") == "hello" + assert parser._convert_param_value("123", "string") == "123" + + # Test integer type - valid integers + assert parser._convert_param_value("123", "integer") == 123 + assert parser._convert_param_value("456", "int") == 456 + # Invalid integer should return original string (due to exception catch) + assert parser._convert_param_value("abc", "integer") == "abc" + + # Test float/number type + assert parser._convert_param_value("123.45", "float") == 123.45 + assert ( + parser._convert_param_value("123.0", "number") == 123 + ) # Should be int when whole number + assert parser._convert_param_value("123.5", "number") == 123.5 + # Invalid float should return original string + assert parser._convert_param_value("abc", "float") == "abc" + + # Test boolean type - valid boolean values + assert parser._convert_param_value("true", "boolean") is True + assert parser._convert_param_value("false", "bool") is False + assert parser._convert_param_value("1", "boolean") is True + assert parser._convert_param_value("0", "boolean") is False + # Invalid boolean should return original string + assert parser._convert_param_value("yes", "boolean") == "yes" + assert parser._convert_param_value("no", "bool") == "no" + + # Test null value + assert parser._convert_param_value("null", "string") is None + assert parser._convert_param_value("null", "integer") is None + + # Test object/array type (JSON) + assert parser._convert_param_value('{"key": "value"}', "object") == {"key": "value"} + assert parser._convert_param_value("[1, 2, 3]", "array") == [1, 2, 3] + # Invalid JSON should return original string + assert parser._convert_param_value("{invalid}", "object") == "{invalid}" + + # Test fallback for unknown type (tries json.loads, then returns original) + assert parser._convert_param_value('{"key": "value"}', "unknown") == { + "key": "value" + } + assert parser._convert_param_value("plain text", "unknown") == "plain text" + + +def test_convert_param_value_multi_typed_values(parser): + """Test _convert_param_value with multi-typed values (list of types).""" + # Test with list of types where first type succeeds + assert parser._convert_param_value("123", ["integer", "string"]) == 123 + assert parser._convert_param_value("true", ["boolean", "string"]) is True + assert parser._convert_param_value('{"x": 1}', ["object", "string"]) == {"x": 1} + + # Test with list of types where first type fails but second succeeds + # "abc" is not a valid integer, so should try string next + assert parser._convert_param_value("abc", ["integer", "string"]) == "abc" + + # Test with list of types where all fail - should return original value + # "invalid json" is not valid JSON, last type is "object" which will fail JSON parse + result = parser._convert_param_value("invalid json", ["integer", "object"]) + assert result == "invalid json" # Returns original value after all types fail + + # Test with three types + assert parser._convert_param_value("123.5", ["integer", "float", "string"]) == 123.5 + assert parser._convert_param_value("true", ["integer", "boolean", "string"]) is True + + # Test with null in multi-type list + assert parser._convert_param_value("null", ["integer", "string"]) is None + assert parser._convert_param_value("null", ["boolean", "object"]) is None + + # Test nested type conversion - boolean fails, integer succeeds + value = parser._convert_param_value("123", ["boolean", "integer", "string"]) + assert value == 123 # Should be integer, not boolean + + # Test that order matters + assert ( + parser._convert_param_value("123", ["string", "integer"]) == "123" + ) # String first + assert ( + parser._convert_param_value("123", ["integer", "string"]) == 123 + ) # Integer first + + # Test with all types failing - returns original value + assert ( + parser._convert_param_value("not_a_number", ["integer", "float", "boolean"]) + == "not_a_number" + ) + + +def test_convert_param_value_stricter_type_checking(parser): + """Test stricter type checking in the updated implementation.""" + # Boolean now has stricter validation + assert parser._convert_param_value("true", "boolean") is True + assert parser._convert_param_value("false", "boolean") is False + assert parser._convert_param_value("1", "boolean") is True + assert parser._convert_param_value("0", "boolean") is False + + # These should return original string (not valid boolean values) + assert parser._convert_param_value("yes", "boolean") == "yes" + assert parser._convert_param_value("no", "boolean") == "no" + assert parser._convert_param_value("TRUE", "boolean") is True + assert parser._convert_param_value("FALSE", "boolean") is False + + # Integer and float now raise exceptions for invalid values + assert parser._convert_param_value("123abc", "integer") == "123abc" + assert parser._convert_param_value("123.45.67", "float") == "123.45.67" + + # JSON parsing is stricter - invalid JSON returns original + assert parser._convert_param_value("{invalid: json}", "object") == "{invalid: json}" + assert parser._convert_param_value("[1, 2,", "array") == "[1, 2," + + # Test multi-type with stricter checking + # "yes" is not valid boolean, but string would accept it + assert parser._convert_param_value("yes", ["boolean", "string"]) == "yes" + + # "123abc" is not valid integer or float, but string accepts it + assert ( + parser._convert_param_value("123abc", ["integer", "float", "string"]) + == "123abc" + ) + + +def test_convert_param_value_edge_cases(parser): + """Test edge cases for _convert_param_value.""" + # Empty string + assert parser._convert_param_value("", "string") == "" + assert ( + parser._convert_param_value("", "integer") == "" + ) # Invalid int returns original + + # Whitespace - trimmed by conversion functions + assert parser._convert_param_value(" 123 ", "integer") == 123 + assert parser._convert_param_value(" true ", "boolean") is True + + # Numeric strings with special characters + assert parser._convert_param_value("123.45.67", "float") == "123.45.67" + assert parser._convert_param_value("123abc", "integer") == "123abc" + + # JSON with whitespace - should parse correctly + assert parser._convert_param_value(' { "key" : "value" } ', "object") == { + "key": "value" + } + + # Invalid JSON returns original + assert parser._convert_param_value("{invalid}", "object") == "{invalid}" + assert parser._convert_param_value("[1, 2,", "array") == "[1, 2," + + +def test_convert_param_value_checked_helper(parser): + """Test the _convert_param_value_checked helper function indirectly.""" + # This tests the behavior through the main function + # Valid conversions should work + assert parser._convert_param_value("123", "integer") == 123 + assert parser._convert_param_value("123.45", "float") == 123.45 + assert parser._convert_param_value("true", "boolean") is True + assert parser._convert_param_value('{"x": 1}', "object") == {"x": 1} + + # Invalid conversions should return original value (exception caught) + assert parser._convert_param_value("abc", "integer") == "abc" + assert parser._convert_param_value("abc", "float") == "abc" + assert parser._convert_param_value("yes", "boolean") == "yes" + assert parser._convert_param_value("{invalid}", "object") == "{invalid}" + + # Test that null handling works in checked function + assert parser._convert_param_value("null", "integer") is None + assert parser._convert_param_value("null", "boolean") is None + assert parser._convert_param_value("null", "object") is None diff --git a/vllm/tool_parsers/deepseekv32_tool_parser.py b/vllm/tool_parsers/deepseekv32_tool_parser.py index e86929944799..a9772b753134 100644 --- a/vllm/tool_parsers/deepseekv32_tool_parser.py +++ b/vllm/tool_parsers/deepseekv32_tool_parser.py @@ -100,7 +100,7 @@ def _parse_invoke_params(self, invoke_str: str) -> dict: param_dict[param_name] = param_val return param_dict - def _convert_param_value(self, value: str, param_type: str) -> Any: + def _convert_param_value_checked(self, value: str, param_type: str) -> Any: """Convert parameter value to the correct type.""" if value.lower() == "null": return None @@ -109,29 +109,31 @@ def _convert_param_value(self, value: str, param_type: str) -> Any: if param_type in ["string", "str", "text"]: return value elif param_type in ["integer", "int"]: - try: - return int(value) - except (ValueError, TypeError): - return value + return int(value) elif param_type in ["number", "float"]: - try: - val = float(value) - return val if val != int(val) else int(val) - except (ValueError, TypeError): - return value + val = float(value) + return val if val != int(val) else int(val) elif param_type in ["boolean", "bool"]: + value = value.strip() + if value.lower() not in ["false", "0", "true", "1"]: + raise ValueError("Invalid boolean value") return value.lower() in ["true", "1"] elif param_type in ["object", "array"]: - try: - return json.loads(value) - except json.JSONDecodeError: - return value + return json.loads(value) else: - # Try JSON parse first, fallback to string + return json.loads(value) + + def _convert_param_value(self, value: str, param_type: str | list[str]) -> Any: + """Convert parameter value to the correct type.""" + if not isinstance(param_type, list): + param_type = [param_type] + for current_type in param_type: try: - return json.loads(value) - except json.JSONDecodeError: - return value + return self._convert_param_value_checked(value, current_type) + except Exception: + continue + # return value as fallback + return value def _convert_params_with_schema( self, From c133f3374625652c88e122fff995e4126c4635c0 Mon Sep 17 00:00:00 2001 From: Jiangyun Zhu Date: Mon, 30 Mar 2026 12:10:00 +0800 Subject: [PATCH 0003/1601] Add @ZJY0516 to CODEOWNERS (#38497) Signed-off-by: zjy0516 --- .github/CODEOWNERS | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index c0ceae044d25..37707da8af26 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -9,6 +9,7 @@ /vllm/model_executor/layers/fused_moe @mgoin @pavanimajety /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety /vllm/model_executor/layers/mamba @tdoublep +/vllm/model_executor/layers/mamba/gdn_linear_attn.py @tdoublep @ZJY0516 /vllm/model_executor/model_loader @22quinn /vllm/model_executor/layers/batch_invariant.py @yewentao256 /vllm/multimodal @DarkLight1337 @ywang96 @NickLucche @tjtanaa @@ -48,6 +49,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /vllm/v1/attention/backends/mla @pavanimajety /vllm/v1/attention/backends/flashinfer.py @mgoin @pavanimajety /vllm/v1/attention/backends/triton_attn.py @tdoublep +/vllm/v1/attention/backends/gdn_attn.py @ZJY0516 /vllm/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery /vllm/v1/sample @22quinn @houseroad @njhill /vllm/v1/spec_decode @benchislett @luccafong @MatthewBonanni @@ -142,6 +144,7 @@ mkdocs.yaml @hmellor # Kernels /vllm/v1/attention/ops/chunked_prefill_paged_decode.py @tdoublep /vllm/v1/attention/ops/triton_unified_attention.py @tdoublep +/vllm/model_executor/layers/fla @ZJY0516 # ROCm related: specify owner with write access to notify AMD folks for careful code review /vllm/**/*rocm* @tjtanaa From bea23536f627b2b4153f2e672753b6034b78dedb Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Mon, 30 Mar 2026 00:36:45 -0500 Subject: [PATCH 0004/1601] [CI] Add temperature=0.0, reduce max_tokens, and add debug prints to audio_in_video tests (#38492) Signed-off-by: Andreas Karatzas --- .../chat_completion/test_audio_in_video.py | 22 +++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/tests/entrypoints/openai/chat_completion/test_audio_in_video.py b/tests/entrypoints/openai/chat_completion/test_audio_in_video.py index 8c024995b938..61ee91eab4d0 100644 --- a/tests/entrypoints/openai/chat_completion/test_audio_in_video.py +++ b/tests/entrypoints/openai/chat_completion/test_audio_in_video.py @@ -64,11 +64,12 @@ async def test_online_audio_in_video( ] # multi-turn to test mm processor cache as well - for _ in range(2): + for turn in range(2): chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, - max_tokens=16, + max_tokens=8, + temperature=0.0, extra_body={ "mm_processor_kwargs": { "use_audio_in_video": True, @@ -78,6 +79,12 @@ async def test_online_audio_in_video( assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] + print( + f"[DEBUG][single-video] turn={turn} " + f"finish_reason={choice.finish_reason!r} " + f"content={choice.message.content!r} " + f"usage={chat_completion.usage}" + ) assert choice.finish_reason == "length" @@ -111,11 +118,12 @@ async def test_online_audio_in_video_multi_videos( ] # multi-turn to test mm processor cache as well - for _ in range(2): + for turn in range(2): chat_completion = await client.chat.completions.create( model=MODEL_NAME, messages=messages, - max_tokens=16, + max_tokens=8, + temperature=0.0, extra_body={ "mm_processor_kwargs": { "use_audio_in_video": True, @@ -125,6 +133,12 @@ async def test_online_audio_in_video_multi_videos( assert len(chat_completion.choices) == 1 choice = chat_completion.choices[0] + print( + f"[DEBUG][multi-video] turn={turn} " + f"finish_reason={choice.finish_reason!r} " + f"content={choice.message.content!r} " + f"usage={chat_completion.usage}" + ) assert choice.finish_reason == "length" From 92f0db57a88161af6e76728e5f552a24e6955569 Mon Sep 17 00:00:00 2001 From: Roger Wang Date: Sun, 29 Mar 2026 22:39:41 -0700 Subject: [PATCH 0005/1601] [Misc] Always use `forward_mulmat` for `Conv3d` on newer versions of torch. (#38487) --- vllm/model_executor/layers/conv.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/vllm/model_executor/layers/conv.py b/vllm/model_executor/layers/conv.py index f4709f2f4d80..51314263b735 100644 --- a/vllm/model_executor/layers/conv.py +++ b/vllm/model_executor/layers/conv.py @@ -10,7 +10,7 @@ import torch.nn.functional as F from vllm.model_executor.custom_op import CustomOp -from vllm.utils.torch_utils import is_torch_equal +from vllm.utils.torch_utils import is_torch_equal_or_newer class ConvLayerBase(CustomOp): @@ -252,11 +252,12 @@ def forward_native(self, x: torch.Tensor) -> torch.Tensor: return self._forward_conv(x) def forward_cuda(self, x: torch.Tensor) -> torch.Tensor: - # PyTorch2.9.0 disabled CUDNN's Conv3D, which caused a + # PyTorch 2.9.0+ disabled CUDNN's Conv3D, which caused a # significant performance regression. # See: https://github.com/vllm-project/vllm/issues/27406 # and https://github.com/pytorch/pytorch/issues/166122 + # and https://github.com/huggingface/transformers/pull/45041 # By default, we use CUDNN's convolution ops with optimization. - if self.enable_linear and (is_torch_equal("2.9.0") or is_torch_equal("2.9.1")): + if self.enable_linear and is_torch_equal_or_newer("2.9.0"): return self._forward_mulmat(x) return self._forward_conv(x) From d816834c1aa61da8e8703c6e8c368ad22fe76e13 Mon Sep 17 00:00:00 2001 From: Jaewon <52840625+jaewonlee-fb@users.noreply.github.com> Date: Sun, 29 Mar 2026 22:53:43 -0700 Subject: [PATCH 0006/1601] [MoE] Add RoutingMethodType.Simulated to TRT-LLM FP8/NVFP4 kernel allowlists (#38329) Signed-off-by: Jaewon Lee --- .../layers/fused_moe/experts/trtllm_fp8_moe.py | 15 +++++++++++---- .../layers/fused_moe/experts/trtllm_nvfp4_moe.py | 14 ++++++++++---- 2 files changed, 21 insertions(+), 8 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py index 9a6f67b421f9..c0a7dfc4978a 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -256,13 +256,18 @@ def _supports_router_logits_dtype( ) -> bool: """ The FlashInfer TRTLLM FP8 kernel expects bfloat16 router_logits by default. - Only DeepSeekV3 routing supports float32 router_logits (which is converted - internally in the kernel). + DeepSeekV3 routing supports float32 router_logits (converted internally). + Simulated routing generates synthetic decisions and is agnostic to dtype. """ if router_logits_dtype == torch.float32: - # Only DeepSeekV3 routing handles float32 logits + # DeepSeekV3 routing handles float32 logits internally. + # Simulated routing generates synthetic decisions, so the + # kernel doesn't care about the actual logits dtype. # https://github.com/flashinfer-ai/flashinfer/issues/2469 - return routing_method == RoutingMethodType.DeepSeekV3 + return routing_method in ( + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Simulated, + ) return True @staticmethod @@ -288,12 +293,14 @@ def _supports_routing_method( # NOTE(rob): potentially allow others here. This is a conservative list. return routing_method in [ RoutingMethodType.DeepSeekV3, + RoutingMethodType.Simulated, ] elif (weight_key, activation_key) == (kFp8StaticTensorSym, kFp8StaticTensorSym): # NOTE(dbari): as above, potentially allow others here. return routing_method in [ RoutingMethodType.DeepSeekV3, RoutingMethodType.Llama4, + RoutingMethodType.Simulated, ] else: raise ValueError("Unsupported quantization scheme.") diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py index 84beb6abb553..b47391c41ea0 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py @@ -255,6 +255,7 @@ def _supports_routing_method( RoutingMethodType.Renormalize, RoutingMethodType.RenormalizeNaive, RoutingMethodType.Llama4, + RoutingMethodType.Simulated, ] @staticmethod @@ -264,13 +265,18 @@ def _supports_router_logits_dtype( ) -> bool: """ The FlashInfer TRTLLM NvFp4 kernel expects bfloat16 router_logits by default. - Only DeepSeekV3 routing supports float32 router_logits (which is converted - internally in the kernel). + DeepSeekV3 routing supports float32 router_logits (converted internally). + Simulated routing generates synthetic decisions and is agnostic to dtype. """ if router_logits_dtype == torch.float32: - # Only DeepSeekV3 routing handles float32 logits + # DeepSeekV3 routing handles float32 logits internally. + # Simulated routing generates synthetic decisions, so the + # kernel doesn't care about the actual logits dtype. # https://github.com/flashinfer-ai/flashinfer/issues/2469 - return routing_method == RoutingMethodType.DeepSeekV3 + return routing_method in ( + RoutingMethodType.DeepSeekV3, + RoutingMethodType.Simulated, + ) return True def apply( From fec5aeca12d377b3ab81ba9fb8a0b1972a768f93 Mon Sep 17 00:00:00 2001 From: "Kevin H. Luu" Date: Sun, 29 Mar 2026 23:05:26 -0700 Subject: [PATCH 0007/1601] [ci] Soft fail and disable retry for AMD build image job (#38505) Signed-off-by: Kevin H. Luu --- .buildkite/hardware_tests/amd.yaml | 9 +-------- 1 file changed, 1 insertion(+), 8 deletions(-) diff --git a/.buildkite/hardware_tests/amd.yaml b/.buildkite/hardware_tests/amd.yaml index 23a23723ad93..b2c0283ecea9 100644 --- a/.buildkite/hardware_tests/amd.yaml +++ b/.buildkite/hardware_tests/amd.yaml @@ -5,6 +5,7 @@ steps: depends_on: [] device: amd_cpu no_plugin: true + soft_fail: true commands: - > docker build @@ -20,11 +21,3 @@ steps: - docker push "rocm/vllm-ci:${BUILDKITE_COMMIT}" env: DOCKER_BUILDKIT: "1" - retry: - automatic: - - exit_status: -1 # Agent was lost - limit: 1 - - exit_status: -10 # Agent was lost - limit: 1 - - exit_status: 1 # Machine occasionally fail - limit: 1 From 63babd17f1b110e267e1ad801a9b9d4ccf5bbe7d Mon Sep 17 00:00:00 2001 From: PikaPikachu Date: Mon, 30 Mar 2026 14:24:06 +0800 Subject: [PATCH 0008/1601] [Model][Quantization] Add GGUF support for MiniMax-M2.1 (#36965) Signed-off-by: kangletian --- vllm/config/model.py | 1 + .../layers/quantization/gguf.py | 15 +++- .../model_loader/gguf_loader.py | 81 +++++++++++++++++-- .../model_loader/weight_utils.py | 43 ++++++++++ vllm/model_executor/models/minimax_m2.py | 7 +- 5 files changed, 137 insertions(+), 10 deletions(-) diff --git a/vllm/config/model.py b/vllm/config/model.py index b8c601334a1d..c4ee654fe8bc 100644 --- a/vllm/config/model.py +++ b/vllm/config/model.py @@ -948,6 +948,7 @@ def _verify_quantization(self) -> None: # imports during override detection (e.g., MXFP4 imports Triton) "mxfp4", "cpu_awq", + "gguf", ] quantization_methods = [ q for q in supported_quantization if q not in overrides diff --git a/vllm/model_executor/layers/quantization/gguf.py b/vllm/model_executor/layers/quantization/gguf.py index 88023349e779..145610e9c0d2 100644 --- a/vllm/model_executor/layers/quantization/gguf.py +++ b/vllm/model_executor/layers/quantization/gguf.py @@ -3,7 +3,10 @@ from collections.abc import Mapping from types import MappingProxyType -from typing import Any +from typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from vllm.model_executor.layers.quantization import QuantizationMethods import gguf import torch @@ -79,6 +82,16 @@ def get_config_filenames(cls) -> list[str]: def from_config(cls, config: dict[str, Any]) -> "GGUFConfig": return cls() + @classmethod + def override_quantization_method( + cls, hf_quant_cfg: dict[str, Any], user_quant: str | None + ) -> "QuantizationMethods | None": + # When user explicitly specifies --quantization gguf, override + # whatever quantization method is in the HF model config (e.g. fp8). + if user_quant == "gguf": + return "gguf" + return None + def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> "QuantizeMethodBase | None": diff --git a/vllm/model_executor/model_loader/gguf_loader.py b/vllm/model_executor/model_loader/gguf_loader.py index 75d0b3425460..ce6a813b8da5 100644 --- a/vllm/model_executor/model_loader/gguf_loader.py +++ b/vllm/model_executor/model_loader/gguf_loader.py @@ -24,6 +24,7 @@ get_gguf_extra_tensor_names, get_gguf_weight_type_map, gguf_quant_weights_iterator, + gguf_quant_weights_iterator_multi, ) from vllm.transformers_utils.gguf_utils import detect_gguf_multimodal from vllm.utils.torch_utils import set_default_torch_dtype @@ -74,6 +75,31 @@ def _prepare_weights(self, model_config: ModelConfig): "or :)" ) + @staticmethod + def _get_all_gguf_files(model_path: str) -> list[str]: + """Discover all GGUF shard files from a single shard path. + + Supports variable-width shard indices by dynamically detecting + the padding from the original filename. + E.g. ``*-00001-of-00005.gguf`` → all 5 shards, + ``*-01-of-15.gguf`` → all 15 shards. + """ + match = re.search(r"-(\d+)-of-(\d+)\.gguf$", model_path) + if not match: + return [model_path] + total = int(match.group(2)) + num_digits = len(match.group(1)) + prefix = model_path[: match.start(1)] + suffix = model_path[match.end(2) :] + files = [] + for i in range(1, total + 1): + shard_path = f"{prefix}{i:0{num_digits}d}-of-{total:0{num_digits}d}{suffix}" + if os.path.isfile(shard_path): + files.append(shard_path) + if files: + logger.info("Discovered %d GGUF shard files", len(files)) + return files if files else [model_path] + def _get_gguf_weights_map(self, model_config: ModelConfig): """ GGUF uses this naming convention for their tensors from HF checkpoint: @@ -145,6 +171,29 @@ def _get_gguf_weights_map(self, model_config: ModelConfig): r"\.mlp\.experts\.[0-9]+\.(gate|up|down)_proj\.weight" ) ) + if model_type == "minimax_m2": + model_type = "minimax-m2" + # GGUF layer map assumes merged expert weights + # map them manually like deepseek2 + for idx in range(config.num_hidden_layers): + gguf_to_hf_name_map[f"blk.{idx}.exp_probs_b.bias"] = ( + f"model.layers.{idx}.block_sparse_moe.e_score_correction_bias" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_down_exps.weight"] = ( + f"model.layers.{idx}.block_sparse_moe.experts.0.w2.weight" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_gate_exps.weight"] = ( + f"model.layers.{idx}.block_sparse_moe.experts.0.w1.weight" + ) + gguf_to_hf_name_map[f"blk.{idx}.ffn_up_exps.weight"] = ( + f"model.layers.{idx}.block_sparse_moe.experts.0.w3.weight" + ) + sideload_params.append( + re.compile( + f"model\\.layers\\.{idx}" + r"\.block_sparse_moe\.experts\.(gate_up_proj|down_proj)" + ) + ) arch = None for key, value in gguf.MODEL_ARCH_NAMES.items(): @@ -190,6 +239,13 @@ def revert_hf_rename(name: str) -> str: revert_hf_rename(name): tensor for name, tensor in state_dict.items() } + if model_type == "minimax-m2" and not hf_checkpoint_map: + # Reverse HF convention: mlp -> block_sparse_moe + state_dict = { + name.replace(".mlp.", ".block_sparse_moe."): tensor + for name, tensor in state_dict.items() + } + def find_hf_name_in_tensor_map(hf_name: str) -> str | None: """ Map HuggingFace parameter name to GGUF tensor name. @@ -277,9 +333,10 @@ def _get_gguf_weight_type( model_name_or_path: str, gguf_to_hf_name_map: dict[str, str], ) -> dict[str, str]: - weight_type_map = get_gguf_weight_type_map( - model_name_or_path, gguf_to_hf_name_map - ) + gguf_files = self._get_all_gguf_files(model_name_or_path) + weight_type_map = {} + for f in gguf_files: + weight_type_map.update(get_gguf_weight_type_map(f, gguf_to_hf_name_map)) is_multimodal = hasattr(model_config.hf_config, "vision_config") if is_multimodal: mmproj_file = detect_gguf_multimodal(model_name_or_path) @@ -321,7 +378,15 @@ def _get_weights_iterator( ) yield from gguf_quant_weights_iterator(mmproj_file, gguf_to_hf_name_map) - yield from gguf_quant_weights_iterator(model_name_or_path, gguf_to_hf_name_map) + gguf_files = self._get_all_gguf_files(model_name_or_path) + if len(gguf_files) > 1: + yield from gguf_quant_weights_iterator_multi( + gguf_files, gguf_to_hf_name_map + ) + else: + yield from gguf_quant_weights_iterator( + model_name_or_path, gguf_to_hf_name_map + ) def download_model(self, model_config: ModelConfig) -> None: self._prepare_weights(model_config) @@ -340,9 +405,11 @@ def load_model( local_model_path = self._prepare_weights(model_config) gguf_weights_map = self._get_gguf_weights_map(model_config) # we can only know if tie word embeddings after mapping weights - if "lm_head.weight" in get_gguf_extra_tensor_names( - local_model_path, gguf_weights_map - ): + gguf_files = self._get_all_gguf_files(local_model_path) + all_extra_names = [] + for f in gguf_files: + all_extra_names.extend(get_gguf_extra_tensor_names(f, gguf_weights_map)) + if "lm_head.weight" in all_extra_names: model_config.hf_config.update({"tie_word_embeddings": True}) weight_type_map = self._get_gguf_weight_type( diff --git a/vllm/model_executor/model_loader/weight_utils.py b/vllm/model_executor/model_loader/weight_utils.py index 37023d3f1f5c..bd1dd20013f6 100644 --- a/vllm/model_executor/model_loader/weight_utils.py +++ b/vllm/model_executor/model_loader/weight_utils.py @@ -1222,6 +1222,49 @@ def gguf_quant_weights_iterator( yield name, param +def gguf_quant_weights_iterator_multi( + gguf_files: list[str], gguf_to_hf_name_map: dict[str, str] +) -> Generator[tuple[str, torch.Tensor], None, None]: + """ + Iterate over the quant weights across multiple GGUF shard files + and convert them to torch tensors. + + Like gguf_quant_weights_iterator, we yield all weight types first + before yielding any weights data to avoid issues with packed layers + that have different quant types. + """ + readers = [gguf.GGUFReader(f) for f in gguf_files] + + # First pass: yield all weight types across all shards + for reader in readers: + for tensor in reader.tensors: + if tensor.name in gguf_to_hf_name_map: + weight_type = tensor.tensor_type + name = gguf_to_hf_name_map[tensor.name] + if weight_type.name not in ("F32", "BF16", "F16"): + weight_type_name = name.replace("weight", "qweight_type") + weight_type = torch.tensor(weight_type) + yield weight_type_name, weight_type + + # Second pass: yield all weight data across all shards + for reader in readers: + for tensor in reader.tensors: + if tensor.name in gguf_to_hf_name_map: + weight = tensor.data + weight_type = tensor.tensor_type + name = gguf_to_hf_name_map[tensor.name] + if weight_type.name not in ("F32", "BF16", "F16"): + name = name.replace("weight", "qweight") + if weight_type.name == "BF16" and tensor.data.dtype == np.uint8: + weight = weight.view(np.uint16) + if reader.byte_order == "S": + weight = weight.byteswap() + param = torch.tensor(weight).view(torch.bfloat16) + else: + param = torch.tensor(weight) + yield name, param + + def convert_pyslice_to_tensor(x: Any) -> torch.Tensor: """convert PySafeSlice object from safetensors to torch.Tensor diff --git a/vllm/model_executor/models/minimax_m2.py b/vllm/model_executor/models/minimax_m2.py index 426caea1f047..0f43bc0cdcec 100644 --- a/vllm/model_executor/models/minimax_m2.py +++ b/vllm/model_executor/models/minimax_m2.py @@ -331,7 +331,7 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, - quant_config=None, + quant_config=quant_config, prefix=f"{prefix}.embed_tokens", ) else: @@ -518,7 +518,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) if get_pp_group().is_last_rank: self.lm_head = ParallelLMHead( - config.vocab_size, config.hidden_size, quant_config=None + config.vocab_size, + config.hidden_size, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "lm_head"), ) else: self.lm_head = PPMissingLayer() From ac30a8311ec286f85ddd1b8f8773ccfbfcf8ba08 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Mon, 30 Mar 2026 14:59:42 +0800 Subject: [PATCH 0009/1601] [Bugfix][Model] Fix PixtralForConditionalGeneration LoRA (#36963) Signed-off-by: Jee Jee Li Co-authored-by: Roger Wang --- vllm/model_executor/models/pixtral.py | 235 ++++++++++++++++++++------ 1 file changed, 183 insertions(+), 52 deletions(-) diff --git a/vllm/model_executor/models/pixtral.py b/vllm/model_executor/models/pixtral.py index 0d891b8c9f20..e179638a869b 100644 --- a/vllm/model_executor/models/pixtral.py +++ b/vllm/model_executor/models/pixtral.py @@ -8,7 +8,6 @@ import torch import torch.nn as nn -import torch.nn.functional as F from mistral_common.protocol.instruct.chunk import ImageChunk, TextChunk from mistral_common.protocol.instruct.messages import UserMessage from mistral_common.protocol.instruct.request import ChatCompletionRequest @@ -26,16 +25,18 @@ from vllm.config.multimodal import BaseDummyOptions from vllm.distributed import divide, get_tensor_model_parallel_world_size from vllm.inputs import MultiModalDataDict -from vllm.model_executor.layers.activation import get_act_and_mul_fn +from vllm.model_executor.layers.activation import SiluAndMul, get_act_and_mul_fn from vllm.model_executor.layers.conv import Conv2dLayer from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, + ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.model_loader.weight_utils import default_weight_loader +from vllm.model_executor.models.utils import WeightsMapper from vllm.multimodal import MULTIMODAL_REGISTRY, MultiModalKwargsItems from vllm.multimodal.inputs import ( MultiModalFieldConfig, @@ -293,6 +294,23 @@ def _cached_apply_hf_processor( class PixtralForConditionalGeneration( nn.Module, SupportsLoRA, SupportsEagle3, SupportsMultiModal, SupportsPP ): + hf_to_vllm_mapper = WeightsMapper( + orig_to_new_prefix={ + "model.language_model.": "language_model.model.", + "model.vision_tower.": "vision_encoder.", + "model.multi_modal_projector.": "vision_language_adapter.", + }, + orig_to_new_substr={ + ".linear_1.": ".w_in.", + ".linear_2.": ".w_out.", + }, + ) + + packed_modules_mapping = { + "qkv_proj": ["q_proj", "k_proj", "v_proj"], + "gate_up_proj": ["gate_proj", "up_proj"], + } + @classmethod def get_placeholder_str(cls, modality: str, i: int) -> str | None: if modality.startswith("image"): @@ -325,7 +343,10 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): ) with self._mark_tower_model(vllm_config, "image"): - self.vision_encoder = VisionTransformer(self.vision_args) + self.vision_encoder = VisionTransformer( + self.vision_args, + prefix=maybe_prefix(prefix, "vision_encoder"), + ) self.pre_mm_projector_norm = ( RMSNorm(self.vision_args.hidden_size, eps=1e-5) if self.vision_args.add_pre_mm_projector_layer_norm @@ -435,6 +456,15 @@ def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]: return self.language_model.get_eagle3_aux_hidden_state_layers() def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]): + _vision_encoder_stacked_params = [ + # (param_name, shard_name, shard_id) + (".qkv_proj", ".q_proj", "q"), + (".qkv_proj", ".k_proj", "k"), + (".qkv_proj", ".v_proj", "v"), + (".gate_up_proj", ".gate_proj", 0), + (".gate_up_proj", ".up_proj", 1), + ] + def is_vision_encoder_weights(weight: tuple[str, torch.Tensor]): return weight[0].startswith(("vision_encoder", "vision_tower")) @@ -449,7 +479,6 @@ def is_patch_merger(weight: tuple[str, torch.Tensor]): def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]): return weight[0].startswith("pre_mm_projector_norm") - # Get references to parameters for direct loading vision_encoder_dict = ( dict(self.vision_encoder.named_parameters()) if self.vision_encoder is not None @@ -472,29 +501,41 @@ def is_pre_mm_projector_norm(weight: tuple[str, torch.Tensor]): ) def llm_weights_generator(): - # Single pass over weights for name, w in weights: if is_vision_encoder_weights((name, w)): if _is_layer_none_or_staged(self.vision_encoder): continue - # Load vision encoder weights directly trimmed_name = ".".join(name.split(".")[1:]) - param = vision_encoder_dict.get(trimmed_name) - if param is not None: - with torch.no_grad(): - default_weight_loader(param, w) + for ( + param_name, + weight_name, + shard_id, + ) in _vision_encoder_stacked_params: + if weight_name in trimmed_name: + trimmed_name = trimmed_name.replace(weight_name, param_name) + param = vision_encoder_dict[trimmed_name] + weight_loader = param.weight_loader + weight_loader(param, w, shard_id) + break + else: + param = vision_encoder_dict.get(trimmed_name) + if param is not None: + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, w) elif is_patch_merger((name, w)): if _is_layer_none_or_staged(self.patch_merger): continue - # Load vision patch merger weights directly trimmed_name = ".".join(name.split(".")[1:]) param = patch_merger_dict[trimmed_name] - with torch.no_grad(): - default_weight_loader(param, w) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, w) elif is_pre_mm_projector_norm((name, w)): if _is_layer_none_or_staged(self.pre_mm_projector_norm): continue - # Load vision pre_mm_projector_norm weights directly trimmed_name = ".".join(name.split(".")[1:]) param = pre_mm_projector_norm_dict[trimmed_name] with torch.no_grad(): @@ -502,26 +543,23 @@ def llm_weights_generator(): elif is_vision_lang_adapter_weights((name, w)): if _is_layer_none_or_staged(self.vision_language_adapter): continue - # Load vision-language adapter weights directly trimmed_name = ".".join(name.split(".")[1:]) param = vision_lang_adapter_dict.get(trimmed_name) if param is not None: - with torch.no_grad(): - default_weight_loader(param, w) + weight_loader = getattr( + param, "weight_loader", default_weight_loader + ) + weight_loader(param, w) else: - # LLM weights: yield them to be loaded - # by language_model.load_weights - # Strip "language_model." prefix if present (HF sharded format) name = name.removeprefix("language_model.") yield (name, w) - # Now we call the language model load with the generator self.language_model.load_weights(llm_weights_generator()) def get_mm_mapping(self) -> MultiModelKeys: return MultiModelKeys.from_string_field( - language_model="language_model", - connector="vision_language_adapter", + language_model="language_model.", + connector="vision_language_adapter.", tower_model="vision_encoder", ) @@ -614,29 +652,78 @@ def apply_rotary_emb_vit( class FeedForward(nn.Module): - def __init__(self, args: VisionEncoderArgs): + def __init__( + self, + hidden_size: int, + intermediate_size: int, + quant_config: QuantizationConfig | None = None, + bias: bool = False, + prefix: str = "", + reduce_results: bool = True, + disable_tp: bool = False, + ) -> None: super().__init__() - assert args.intermediate_size is not None - self.w1 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) - self.w2 = nn.Linear(args.intermediate_size, args.hidden_size, bias=False) - self.w3 = nn.Linear(args.hidden_size, args.intermediate_size, bias=False) - def forward(self, x: torch.Tensor) -> torch.Tensor: - return self.w2(F.silu(self.w1(x)) * self.w3(x)) + self.gate_up_proj = MergedColumnParallelLinear( + input_size=hidden_size, + output_sizes=[intermediate_size] * 2, + bias=bias, + quant_config=quant_config, + disable_tp=disable_tp, + prefix=f"{prefix}.w13", + ) + self.down_proj = RowParallelLinear( + input_size=intermediate_size, + output_size=hidden_size, + bias=bias, + quant_config=quant_config, + reduce_results=reduce_results, + disable_tp=disable_tp, + prefix=f"{prefix}.w2", + ) + + self.act_fn = SiluAndMul() + + def forward(self, x): + x, _ = self.gate_up_proj(x) + x = self.act_fn(x) + x, _ = self.down_proj(x) + return x class Attention(nn.Module): - def __init__(self, args: VisionEncoderArgs): + def __init__( + self, + args: VisionEncoderArgs, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + disable_tp: bool = False, + ): super().__init__() self.args = args assert not args.hidden_size % args.num_attention_heads - self.n_heads = args.num_attention_heads self.head_dim = args.hidden_size // args.num_attention_heads - self.wq = nn.Linear(args.hidden_size, args.hidden_size, bias=False) - self.wk = nn.Linear(args.hidden_size, args.hidden_size, bias=False) - self.wv = nn.Linear(args.hidden_size, args.hidden_size, bias=False) - self.wo = nn.Linear(args.hidden_size, args.hidden_size, bias=False) + self.qkv_proj = QKVParallelLinear( + hidden_size=args.hidden_size, + head_size=self.head_dim, + total_num_heads=args.num_attention_heads, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wqkv", + disable_tp=disable_tp, + ) + self.o_proj = RowParallelLinear( + input_size=args.hidden_size, + output_size=args.hidden_size, + bias=False, + quant_config=quant_config, + prefix=f"{prefix}.wo", + disable_tp=disable_tp, + ) + + tp_size = 1 if disable_tp else get_tensor_model_parallel_world_size() + self.n_heads = divide(args.num_attention_heads, tp_size) def forward( self, @@ -646,7 +733,8 @@ def forward( ) -> torch.Tensor: batch, patches, _ = x.shape - q, k, v = self.wq(x), self.wk(x), self.wv(x) + qkv, _ = self.qkv_proj(x) + q, k, v = qkv.chunk(3, dim=-1) q = q.reshape(batch, patches, self.n_heads, self.head_dim) k = k.reshape(batch, patches, self.n_heads, self.head_dim) v = v.reshape(batch, patches, self.n_heads, self.head_dim) @@ -663,14 +751,32 @@ def forward( out = out.transpose(1, 2) out = out.reshape(batch, patches, self.n_heads * self.head_dim) - return self.wo(out) + out, _ = self.o_proj(out) + return out class TransformerBlock(nn.Module): - def __init__(self, args: VisionEncoderArgs): + def __init__( + self, + args: VisionEncoderArgs, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + disable_tp: bool = False, + ): super().__init__() - self.attention = Attention(args) - self.feed_forward = FeedForward(args) + self.attention = Attention( + args, + quant_config=quant_config, + prefix=f"{prefix}.attention", + disable_tp=disable_tp, + ) + self.feed_forward = FeedForward( + args.hidden_size, + args.intermediate_size, + quant_config=quant_config, + prefix=f"{prefix}.feed_forward", + disable_tp=disable_tp, + ) self.attention_norm = RMSNorm(args.hidden_size, eps=1e-5) self.ffn_norm = RMSNorm(args.hidden_size, eps=1e-5) @@ -690,11 +796,24 @@ def forward( class Transformer(nn.Module): - def __init__(self, args: VisionEncoderArgs): + def __init__( + self, + args: VisionEncoderArgs, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + disable_tp: bool = False, + ): super().__init__() self.layers = torch.nn.ModuleList() - for _ in range(args.num_hidden_layers): - self.layers.append(TransformerBlock(args)) + for idx in range(args.num_hidden_layers): + self.layers.append( + TransformerBlock( + args, + quant_config=quant_config, + prefix=f"{prefix}.layers.{idx}", + disable_tp=disable_tp, + ) + ) def forward( self, @@ -727,9 +846,15 @@ def position_meshgrid( class VisionTransformer(nn.Module): - def __init__(self, args: VisionEncoderArgs): + def __init__( + self, + args: VisionEncoderArgs, + quant_config: QuantizationConfig | None = None, + prefix: str = "", + ): super().__init__() self.args = args + disable_tp = is_vit_use_data_parallel() self.patch_conv = Conv2dLayer( in_channels=args.num_channels, out_channels=args.hidden_size, @@ -738,7 +863,12 @@ def __init__(self, args: VisionEncoderArgs): bias=False, ) self.ln_pre = RMSNorm(args.hidden_size, eps=1e-5) - self.transformer = Transformer(args) + self.transformer = Transformer( + args, + quant_config=quant_config, + prefix=f"{prefix}.transformer", + disable_tp=disable_tp, + ) head_dim = self.args.hidden_size // self.args.num_attention_heads assert head_dim % 2 == 0, "ROPE requires even head_dim" @@ -822,13 +952,16 @@ class VisionLanguageAdapter(nn.Module): def __init__(self, args: VisionEncoderArgs, dim: int): super().__init__() assert isinstance(args, VisionEncoderArgs) - self.w_in = nn.Linear( + self.w_in = ReplicatedLinear( args.hidden_size, dim, bias=args.adapter_bias, + return_bias=False, ) self.gelu = nn.GELU() - self.w_out = nn.Linear(dim, dim, bias=args.adapter_bias) + self.w_out = ReplicatedLinear( + dim, dim, bias=args.adapter_bias, return_bias=False + ) def forward(self, x: torch.Tensor) -> torch.Tensor: return self.w_out(self.gelu(self.w_in(x))) @@ -852,10 +985,8 @@ def __init__( self.spatial_merge_size = spatial_merge_size self.mlp_input_dim = mlp_input_dim - self.merging_layer = nn.Linear( - mlp_input_dim, - vision_encoder_dim, - bias=use_mlp_bias, + self.merging_layer = ReplicatedLinear( + mlp_input_dim, vision_encoder_dim, bias=use_mlp_bias, return_bias=False ) def forward( From 57861ae48d3493fa48b4d7d830b7ec9f995783e7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Juan=20P=C3=A9rez=20de=20Algaba?= <124347725+jperezdealgaba@users.noreply.github.com> Date: Mon, 30 Mar 2026 09:10:01 +0200 Subject: [PATCH 0010/1601] (security) Fix SSRF in batch runner download_bytes_from_url (#38482) Signed-off-by: jperezde --- docs/usage/security.md | 4 + tests/entrypoints/openai/test_run_batch.py | 134 ++++++++++++++++++++- vllm/entrypoints/openai/run_batch.py | 53 ++++++-- 3 files changed, 183 insertions(+), 8 deletions(-) diff --git a/docs/usage/security.md b/docs/usage/security.md index b126d2a1ec7f..4879ddbf64ef 100644 --- a/docs/usage/security.md +++ b/docs/usage/security.md @@ -66,6 +66,10 @@ Restrict domains that vLLM can access for media URLs by setting `--allowed-media-domains` to prevent Server-Side Request Forgery (SSRF) attacks. (e.g. `--allowed-media-domains upload.wikimedia.org github.com www.bogotobogo.com`) +This protection applies to both the online serving API (multimodal inputs) and +the **batch runner** (`vllm run-batch`), where `file_url` values in batch +transcription/translation requests are validated against the same allowlist. + Without domain restrictions, a malicious user could supply URLs that: - **Target internal services**: Access internal network endpoints, cloud metadata diff --git a/tests/entrypoints/openai/test_run_batch.py b/tests/entrypoints/openai/test_run_batch.py index bf670105bbc4..cd1daf0bbbc2 100644 --- a/tests/entrypoints/openai/test_run_batch.py +++ b/tests/entrypoints/openai/test_run_batch.py @@ -4,11 +4,15 @@ import json import subprocess import tempfile +from unittest.mock import AsyncMock, MagicMock, patch import pytest from vllm.assets.audio import AudioAsset -from vllm.entrypoints.openai.run_batch import BatchRequestOutput +from vllm.entrypoints.openai.run_batch import ( + BatchRequestOutput, + download_bytes_from_url, +) CHAT_MODEL_NAME = "hmellor/tiny-random-LlamaForCausalLM" EMBEDDING_MODEL_NAME = "intfloat/multilingual-e5-small" @@ -746,3 +750,131 @@ def test_tool_calling(): assert "arguments" in tool_call["function"] # Verify the tool name matches our tool definition assert tool_call["function"]["name"] == "get_current_weather" + + +# --------------------------------------------------------------------------- +# Unit tests for download_bytes_from_url SSRF protection +# --------------------------------------------------------------------------- + + +def _make_aiohttp_mocks(response_data: bytes = b"fake-data", status: int = 200): + """Create mock objects that simulate aiohttp.ClientSession context managers.""" + mock_resp = MagicMock() + mock_resp.status = status + mock_resp.read = AsyncMock(return_value=response_data) + mock_resp.__aenter__ = AsyncMock(return_value=mock_resp) + mock_resp.__aexit__ = AsyncMock(return_value=False) + + mock_session = MagicMock() + mock_session.get = MagicMock(return_value=mock_resp) + mock_session.__aenter__ = AsyncMock(return_value=mock_session) + mock_session.__aexit__ = AsyncMock(return_value=False) + return mock_session + + +@pytest.mark.asyncio +async def test_download_bytes_data_url_bypasses_domain_check(): + """data: URLs must work regardless of the domain allowlist.""" + data_url = f"data:audio/wav;base64,{MINIMAL_WAV_BASE64}" + result = await download_bytes_from_url( + data_url, allowed_media_domains=["example.com"] + ) + assert isinstance(result, bytes) + assert len(result) > 0 + + +@pytest.mark.asyncio +async def test_download_bytes_rejects_disallowed_domain(): + """HTTP URLs whose hostname is not in the allowlist must be rejected.""" + url = "https://evil.internal/secret" + with pytest.raises(ValueError, match="allowed domains"): + await download_bytes_from_url(url, allowed_media_domains=["example.com"]) + + +@pytest.mark.asyncio +async def test_download_bytes_rejects_cloud_metadata_ip(): + """Cloud metadata endpoints must be blocked when an allowlist is set.""" + url = "http://169.254.169.254/latest/meta-data/" + with pytest.raises(ValueError, match="allowed domains"): + await download_bytes_from_url(url, allowed_media_domains=["example.com"]) + + +@pytest.mark.asyncio +async def test_download_bytes_rejects_internal_ip(): + """Private-range IPs must be blocked when an allowlist is set.""" + for internal_url in [ + "http://10.0.0.1/secret", + "http://192.168.1.1/admin", + "http://127.0.0.1:8080/internal", + ]: + with pytest.raises(ValueError, match="allowed domains"): + await download_bytes_from_url( + internal_url, allowed_media_domains=["example.com"] + ) + + +@pytest.mark.asyncio +async def test_download_bytes_allows_permitted_domain(): + """HTTP URLs whose hostname IS in the allowlist must be fetched.""" + url = "https://example.com/audio.wav" + expected = b"audio-bytes" + mock_session = _make_aiohttp_mocks(expected) + + with patch( + "vllm.entrypoints.openai.run_batch.aiohttp.ClientSession", + return_value=mock_session, + ): + result = await download_bytes_from_url( + url, allowed_media_domains=["example.com"] + ) + assert result == expected + + +@pytest.mark.asyncio +async def test_download_bytes_no_allowlist_permits_any_domain(): + """Without an allowlist all HTTP URLs must be attempted (backward compat).""" + url = "https://any-domain.example.org/file.wav" + expected = b"some-data" + mock_session = _make_aiohttp_mocks(expected) + + with patch( + "vllm.entrypoints.openai.run_batch.aiohttp.ClientSession", + return_value=mock_session, + ): + result = await download_bytes_from_url(url, allowed_media_domains=None) + assert result == expected + + +@pytest.mark.asyncio +async def test_download_bytes_empty_allowlist_denies_all(): + """An empty allowlist must deny all HTTP URLs (least privilege).""" + url = "https://any-domain.example.org/file.wav" + with pytest.raises(ValueError, match="allowed domains"): + await download_bytes_from_url(url, allowed_media_domains=[]) + + +@pytest.mark.asyncio +async def test_download_bytes_unsupported_scheme(): + """Unsupported URL schemes must be rejected regardless of allowlist.""" + with pytest.raises(ValueError, match="Unsupported URL scheme"): + await download_bytes_from_url("ftp://example.com/file.wav") + + with pytest.raises(ValueError, match="Unsupported URL scheme"): + await download_bytes_from_url( + "ftp://example.com/file.wav", + allowed_media_domains=["example.com"], + ) + + +@pytest.mark.asyncio +async def test_download_bytes_backslash_bypass(): + """Backslash-@ URL confusion must not bypass the allowlist. + + urllib3.parse_url() and aiohttp/yarl disagree on backslash-before-@. + The fix normalizes through urllib3 before handing to aiohttp. + """ + bypass_url = "http://allowed.example.com\\@evil.internal/secret" + with pytest.raises(ValueError, match="allowed domains"): + await download_bytes_from_url( + bypass_url, allowed_media_domains=["evil.internal"] + ) diff --git a/vllm/entrypoints/openai/run_batch.py b/vllm/entrypoints/openai/run_batch.py index 03a15991d858..3afd9b8cab74 100644 --- a/vllm/entrypoints/openai/run_batch.py +++ b/vllm/entrypoints/openai/run_batch.py @@ -20,7 +20,9 @@ from pydantic_core.core_schema import ValidationInfo from starlette.datastructures import State from tqdm import tqdm +from urllib3.util import parse_url +import vllm.envs as envs from vllm.config import config from vllm.engine.arg_utils import AsyncEngineArgs from vllm.engine.protocol import EngineClient @@ -439,19 +441,25 @@ async def write_file( await write_local_file(path_or_url, batch_outputs) -async def download_bytes_from_url(url: str) -> bytes: +async def download_bytes_from_url( + url: str, + allowed_media_domains: list[str] | None = None, +) -> bytes: """ Download data from a URL or decode from a data URL. Args: url: Either an HTTP/HTTPS URL or a data URL (data:...;base64,...) + allowed_media_domains: If set, only HTTP/HTTPS URLs whose hostname + is in this list are permitted. data: URLs are not subject to + this restriction. Returns: Data as bytes """ parsed = urlparse(url) - # Handle data URLs (base64 encoded) + # Handle data URLs (base64 encoded) - not subject to domain restrictions if parsed.scheme == "data": # Format: data:...;base64, if "," in url: @@ -465,9 +473,24 @@ async def download_bytes_from_url(url: str) -> bytes: # Handle HTTP/HTTPS URLs elif parsed.scheme in ("http", "https"): + if allowed_media_domains is not None: + url_spec = parse_url(url) + if url_spec.hostname not in allowed_media_domains: + raise ValueError( + f"The URL must be from one of the allowed domains: " + f"{allowed_media_domains}. Input URL domain: " + f"{url_spec.hostname}" + ) + # Use the normalized URL to prevent parsing discrepancies + # between urllib3 and aiohttp (e.g. backslash-@ attacks). + url = url_spec.url + async with ( aiohttp.ClientSession() as session, - session.get(url) as resp, + session.get( + url, + allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS, + ) as resp, ): if resp.status != 200: raise Exception( @@ -593,7 +616,10 @@ def handle_endpoint_request( return run_request(handler_fn, request, tracker) -def make_transcription_wrapper(is_translation: bool) -> WrapperFn: +def make_transcription_wrapper( + is_translation: bool, + allowed_media_domains: list[str] | None = None, +) -> WrapperFn: """ Factory function to create a wrapper for transcription/translation handlers. The wrapper converts BatchTranscriptionRequest or BatchTranslationRequest @@ -602,6 +628,8 @@ def make_transcription_wrapper(is_translation: bool) -> WrapperFn: Args: is_translation: If True, process as translation; otherwise process as transcription + allowed_media_domains: If set, only URLs from these domains are + permitted for HTTP/HTTPS fetches. Returns: A function that takes a handler and returns a wrapped handler @@ -619,7 +647,10 @@ async def transcription_wrapper( ): try: # Download data from URL - audio_data = await download_bytes_from_url(batch_request_body.file_url) + audio_data = await download_bytes_from_url( + batch_request_body.file_url, + allowed_media_domains=allowed_media_domains, + ) # Create a mock file from the downloaded audio data mock_file = UploadFile( @@ -691,6 +722,8 @@ async def build_endpoint_registry( serving_embedding = getattr(state, "serving_embedding", None) serving_scores = getattr(state, "serving_scores", None) + allowed_media_domains = getattr(args, "allowed_media_domains", None) + # Registry of endpoint configurations endpoint_registry: dict[str, dict[str, Any]] = { "completions": { @@ -730,7 +763,10 @@ async def build_endpoint_registry( if openai_serving_transcription is not None else None ), - "wrapper_fn": make_transcription_wrapper(is_translation=False), + "wrapper_fn": make_transcription_wrapper( + is_translation=False, + allowed_media_domains=allowed_media_domains, + ), }, "translations": { "url_matcher": lambda url: url == "/v1/audio/translations", @@ -739,7 +775,10 @@ async def build_endpoint_registry( if openai_serving_translation is not None else None ), - "wrapper_fn": make_transcription_wrapper(is_translation=True), + "wrapper_fn": make_transcription_wrapper( + is_translation=True, + allowed_media_domains=allowed_media_domains, + ), }, } From 85c0950b1f647e0b0654fbf3e91a9757b8233752 Mon Sep 17 00:00:00 2001 From: Tan Pin Siang Date: Mon, 30 Mar 2026 15:19:33 +0800 Subject: [PATCH 0011/1601] [ROCm] Enable MORI EP for unquantized MoE with AITER backend (#37529) Signed-off-by: Tan Pin Siang --- .../layers/fused_moe/all2all_utils.py | 17 ++++++++++++----- .../fused_moe/unquantized_fused_moe_method.py | 16 ++++++++++++---- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/vllm/model_executor/layers/fused_moe/all2all_utils.py b/vllm/model_executor/layers/fused_moe/all2all_utils.py index 44c9bb79e154..534004e112f2 100644 --- a/vllm/model_executor/layers/fused_moe/all2all_utils.py +++ b/vllm/model_executor/layers/fused_moe/all2all_utils.py @@ -186,16 +186,23 @@ def maybe_make_prepare_finalize( use_fp8_dispatch = ( quant_config.is_per_act_token or quant_config.is_block_quantized ) - # For PTPC (per token per channel) quant, the scale dim for each token is 1 - # For 1x128 quant, the scale dim for each token is hidden_dim // 128 - scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128 + if use_fp8_dispatch: + # For PTPC (per token per channel) quant, scale dim is 1 + # For 1x128 quant, scale dim is hidden_dim // 128 + quant_dtype = quant_config.quant_dtype + scale_dim = 1 if quant_config.is_per_act_token else moe.hidden_dim // 128 + else: + # Unquantized dispatch (e.g. AITER with defer_input_quant): + # dispatch raw BF16/FP16 data, no scales needed. + quant_dtype = moe.in_dtype + scale_dim = 0 all_to_all_args = dict( rank=all2all_manager.rank, num_ep_ranks=all2all_manager.world_size, - quant_dtype=quant_config.quant_dtype, + quant_dtype=quant_dtype, token_hidden_size=moe.hidden_dim, scale_dim=scale_dim, - scale_type_size=torch.float32.itemsize, + scale_type_size=0 if scale_dim == 0 else torch.float32.itemsize, max_num_tokens_per_dp_rank=moe.max_num_tokens, input_dtype=moe.in_dtype, num_local_experts=moe.num_experts // all2all_manager.world_size, diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index a29d8a7d8dda..38b552b021ba 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -108,10 +108,7 @@ def maybe_make_prepare_finalize( self, routing_tables: tuple[torch.Tensor, torch.Tensor, torch.Tensor] | None = None, ) -> FusedMoEPrepareAndFinalizeModular | None: - if self.unquantized_backend == UnquantizedMoeBackend.AITER: - return None - else: - return super().maybe_make_prepare_finalize(routing_tables) + return super().maybe_make_prepare_finalize(routing_tables) def select_gemm_impl( self, @@ -130,6 +127,17 @@ def select_gemm_impl( max_num_tokens=self.moe.max_num_tokens, num_dispatchers=prepare_finalize.num_dispatchers(), ) + elif ( + self.unquantized_backend == UnquantizedMoeBackend.AITER + and rocm_aiter_ops.is_fused_moe_enabled() + ): + from .rocm_aiter_fused_moe import AiterExperts + + logger.debug("AiterExperts %s", self.moe) + return AiterExperts( + moe_config=self.moe, + quant_config=self.moe_quant_config, + ) else: logger.debug("TritonExperts %s", self.moe) return TritonExperts( From a08b7733fd3c07ad7b28e15a3fd3e75a503989ec Mon Sep 17 00:00:00 2001 From: haosdent Date: Mon, 30 Mar 2026 15:48:33 +0800 Subject: [PATCH 0012/1601] [CI] Fix SPLADE pooler test broken by #38139 (#38495) Signed-off-by: haosdent --- .../language/pooling/test_splade_sparse_pooler.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/tests/models/language/pooling/test_splade_sparse_pooler.py b/tests/models/language/pooling/test_splade_sparse_pooler.py index af4fd764ef53..38a90d07abeb 100644 --- a/tests/models/language/pooling/test_splade_sparse_pooler.py +++ b/tests/models/language/pooling/test_splade_sparse_pooler.py @@ -1,8 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -import types - import pytest import torch import torch.nn as nn @@ -11,6 +9,8 @@ BertMLMHead, SPLADESparsePooler, ) +from vllm.pooling_params import PoolingParams +from vllm.v1.pool.metadata import PoolingMetadata, PoolingStates # --------------------------------------------------------------------- # Functional test: SPLADE formula correctness (no HF download needed) @@ -38,8 +38,12 @@ def test_splade_pooler_matches_reference_formula(B, T, H, V): ], dtype=torch.long, ) - meta = types.SimpleNamespace( - prompt_lens=prompt_lens_tenser, prompt_token_ids=token_ids + meta = PoolingMetadata( + prompt_lens=prompt_lens_tenser, + prompt_token_ids=token_ids, + prompt_token_ids_cpu=token_ids, + pooling_params=[PoolingParams(task="embed")] * B, + pooling_states=[PoolingStates() for _ in range(B)], ) # MLM head (prefer BertMLMHead, fallback to Linear if unavailable) From 03ac6ca8954d491dc39ae169c2623e8ccffba7c6 Mon Sep 17 00:00:00 2001 From: TJian Date: Mon, 30 Mar 2026 17:25:46 +0800 Subject: [PATCH 0013/1601] [ROCm] [DOC] Update the Documentation to include ROCm Nightly Wheel support (#38457) Signed-off-by: tjtanaa --- .../installation/gpu.rocm.inc.md | 125 +++++++++++++++++- 1 file changed, 118 insertions(+), 7 deletions(-) diff --git a/docs/getting_started/installation/gpu.rocm.inc.md b/docs/getting_started/installation/gpu.rocm.inc.md index 101ab9d56119..4ab01ee8c687 100644 --- a/docs/getting_started/installation/gpu.rocm.inc.md +++ b/docs/getting_started/installation/gpu.rocm.inc.md @@ -1,7 +1,14 @@ --8<-- [start:installation] -vLLM supports AMD GPUs with ROCm 6.3 or above. Pre-built wheels are available for ROCm 7.0. +vLLM supports AMD GPUs with ROCm 6.3 or above. Pre-built wheels are available for ROCm 7.0 and ROCm 7.2.1. + +#### Prebuilt Wheels + +| ROCm Variant | Python Version | ROCm Version | glibc Requirement | Supported Versions | +| ------------ | -------------- | ------------ | ----------------- | ------------------ | +| `rocm700` | 3.12 | 7.0 | >= 2.35 | `0.14.0` to `0.18.0` | +| `rocm721` | 3.12 | 7.2.1 | >= 2.35 | Nightly releases after commit `171775f306a333a9cf105bfd533bf3e113d401d9` | --8<-- [end:installation] --8<-- [start:requirements] @@ -23,26 +30,112 @@ If you need a different ROCm version or want to use an existing PyTorch installa To install the latest version of vLLM for Python 3.12, ROCm 7.0 and `glibc >= 2.35`. ```bash -uv pip install vllm --extra-index-url https://wheels.vllm.ai/rocm/ +uv pip install vllm --extra-index-url https://wheels.vllm.ai/rocm/ --upgrade ``` !!! tip - You can find out about which ROCm version the latest vLLM supports by checking the index in extra-index-url [https://wheels.vllm.ai/rocm/](https://wheels.vllm.ai/rocm/) . + You can find out about which ROCm version the latest vLLM supports by checking the `vllm` package in index in extra-index-url at [https://wheels.vllm.ai/rocm/vllm](https://wheels.vllm.ai/rocm/vllm) . + + Another approach is that you can use this following commands to automatically extract the wheel variants: + + ```bash + # automatically extract the available rocm variant + export VLLM_ROCM_VARIANT=$(curl -s https://wheels.vllm.ai/rocm/vllm | grep -oP 'rocm\d+' | head -1) + + # automatically extract the vLLM version + export VLLM_VERSION=$(curl -s https://wheels.vllm.ai/rocm/vllm | grep -oP 'vllm-\K[0-9.]+' | head -1) + + # inspect if the ROCm version is compatible with your environment + echo $VLLM_ROCM_VARIANT + echo $VLLM_VERSION + ``` To install a specific version and ROCm variant of vLLM wheel. ```bash -uv pip install vllm --extra-index-url https://wheels.vllm.ai/rocm/0.15.0/rocm700 +# version without the `v` +uv pip install vllm==${VLLM_VERSION} --extra-index-url https://wheels.vllm.ai/rocm/${VLLM_VERSION}/${VLLM_ROCM_VARIANT} + +# Example +uv pip install vllm==0.18.0 --extra-index-url https://wheels.vllm.ai/rocm/0.18.0/rocm700 ``` !!! warning "Caveats for using `pip`" - We recommend leveraging `uv` to install vLLM wheel. Using `pip` to install from custom indices is cumbersome, because `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install wheel from custom index if exact versions of all packages are specified exactly. In contrast, `uv` gives the extra index [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). + We recommend leveraging `uv` to install the vLLM wheel. Using `pip` to install from custom indices is cumbersome because `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version. This makes it difficult to install a wheel from a custom index unless exact versions of all packages are specified. In contrast, `uv` gives the extra index [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). - If you insist on using `pip`, you have to specify the exact vLLM version and full URL of the wheel path `https://wheels.vllm.ai/rocm//` (which can be obtained from the web page). + If you insist on using `pip`, you need to specify the exact vLLM version in the package name and provide the custom index URL `https://wheels.vllm.ai/rocm/${VLLM_VERSION}/${VLLM_ROCM_VARIANT}` via `--extra-index-url`. ```bash - pip install vllm==0.15.0+rocm700 --extra-index-url https://wheels.vllm.ai/rocm/0.15.0/rocm700 + pip install vllm==0.18.0+rocm700 --extra-index-url https://wheels.vllm.ai/rocm/0.18.0/rocm700 + ``` + +#### Install the latest code + +LLM inference is a fast-evolving field, and the latest code may contain bug fixes, performance improvements, and new features that are not released yet. To allow users to try the latest code without waiting for the next release, vLLM provides wheels for every commit since commit `171775f306a333a9cf105bfd533bf3e113d401d9` on . The custom index to be used is `https://wheels.vllm.ai/rocm/nightly/${VLLM_ROCM_VARIANT}` + +**NOTE:** The first ROCm Variant that supports nightly wheel is ROCm 7.2.1 + +To install from latest nightly index, run: + +```bash +# automatically extract the available rocm variant +export VLLM_ROCM_VARIANT=$(curl -s https://wheels.vllm.ai/rocm/nightly | \ + grep -oP 'rocm\d+' | head -1 | sed 's/%2B/+/g') + +# inspect if the ROCm version is compatible with your environment +echo $VLLM_ROCM_VARIANT + +uv pip install --pre vllm \ + --extra-index-url https://wheels.vllm.ai/rocm/nightly/${VLLM_ROCM_VARIANT} \ + --index-strategy unsafe-best-match +``` + +##### Install specific revisions + +If you want to access the wheels for previous commits (e.g. to bisect the behavior change, performance regression), you can specify the commit hash in the URL, example: + +```bash +export VLLM_COMMIT=5b8c30d62b754b575e043ce2fc0dcbf8a64f6306 + +export VLLM_ROCM_VARIANT=$(curl -s https://wheels.vllm.ai/rocm/${VLLM_COMMIT} | \ + grep -oP 'rocm\d+' | head -1 | sed 's/%2B/+/g') + +# Extract the version from the wheel URL +export VLLM_VERSION=$(curl -s https://wheels.vllm.ai/rocm/${VLLM_COMMIT}/${VLLM_ROCM_VARIANT}/vllm/ | \ + grep -oP 'vllm-\K[^-]+' | head -1 | sed 's/%2B/+/g') + +# inspect the version if it is compatible with the ROCm version of your environment +echo $VLLM_ROCM_VARIANT +echo $VLLM_VERSION + +uv pip install vllm==${VLLM_VERSION} \ + --extra-index-url https://wheels.vllm.ai/rocm/${VLLM_COMMIT}/${VLLM_ROCM_VARIANT} \ + --index-strategy unsafe-best-match +``` + +!!! warning "`pip` caveat" + + Using `pip` to install from nightly indices is _not supported_, because `pip` combines packages from `--extra-index-url` and the default index, choosing only the latest version, which makes it difficult to install a development version prior to the released version. In contrast, `uv` gives the extra index [higher priority than the default index](https://docs.astral.sh/uv/pip/compatibility/#packages-that-exist-on-multiple-indexes). + + If you insist on using `pip`, you need to specify the exact vLLM version in the package name and provide the custom index URL (which can be obtained from the web page). + + ```bash + export VLLM_COMMIT=5b8c30d62b754b575e043ce2fc0dcbf8a64f6306 + + export VLLM_ROCM_VARIANT=$(curl -s https://wheels.vllm.ai/rocm/${VLLM_COMMIT} | \ + grep -oP 'rocm\d+' | head -1 | sed 's/%2B/+/g') + + # Extract the version from the wheel URL + export VLLM_VERSION=$(curl -s https://wheels.vllm.ai/rocm/${VLLM_COMMIT}/${VLLM_ROCM_VARIANT}/vllm/ | \ + grep -oP 'vllm-\K[^-]+' | head -1 | sed 's/%2B/+/g') + + # inspect the version if it is compatible with the ROCm version of your environment + echo $VLLM_ROCM_VARIANT + echo $VLLM_VERSION + + pip install vllm==${VLLM_VERSION} \ + --extra-index-url https://wheels.vllm.ai/rocm/${VLLM_COMMIT}/${VLLM_ROCM_VARIANT} ``` --8<-- [end:pre-built-wheels] @@ -193,6 +286,24 @@ docker run --rm \ --model Qwen/Qwen3-0.6B ``` +To use the docker image as base for development, you can launch it in interactive session through overriding the entrypoint. + +???+ console "Commands" + ```bash + docker run --rm -it \ + --group-add=video \ + --cap-add=SYS_PTRACE \ + --security-opt seccomp=unconfined \ + --device /dev/kfd \ + --device /dev/dri \ + -v ~/.cache/huggingface:/root/.cache/huggingface \ + --env "HF_TOKEN=$HF_TOKEN" \ + --network=host \ + --ipc=host \ + --entrypoint /bin/bash \ + vllm/vllm-openai-rocm: + ``` + #### Use AMD's Docker Images (Deprecated) !!! warning "Deprecated" From cc06b4e86b2beb04fbee3e6d9167cc97f1491b1f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Nicol=C3=B2=20Lucchesi?= Date: Mon, 30 Mar 2026 11:41:50 +0200 Subject: [PATCH 0014/1601] [Mamba][Bugfix] Raise on insufficient cache blocks instead of silently capping cudagraph sizes (#38270) Signed-off-by: NickLucche --- tests/compile/test_config.py | 42 ------------------- tests/v1/worker/test_gpu_model_runner.py | 29 ++++--------- vllm/config/compilation.py | 52 ------------------------ vllm/v1/worker/gpu_model_runner.py | 49 +++++++++++++++------- 4 files changed, 41 insertions(+), 131 deletions(-) diff --git a/tests/compile/test_config.py b/tests/compile/test_config.py index 53434b0b4c68..43fe49037c2f 100644 --- a/tests/compile/test_config.py +++ b/tests/compile/test_config.py @@ -577,48 +577,6 @@ def test_compile_sizes_padding_validation(): dispatcher.initialize_cudagraph_keys(CUDAGraphMode.NONE) # Should not raise -@pytest.mark.parametrize( - "capture_sizes, max_size, num_blocks, expected_sizes, expected_max", - [ - # Normal capping: sizes filtered to <= num_blocks - ( - [1, 2, 4, 8, 16, 32, 64, 128, 256, 512], - 512, - 200, - [1, 2, 4, 8, 16, 32, 64, 128], - 128, - ), - # No capping needed: num_blocks >= max - ([1, 2, 4, 8, 16], 16, 1000, [1, 2, 4, 8, 16], 16), - # Exact boundary: num_blocks == max (no capping) - ([1, 2, 4, 8, 16, 32], 32, 32, [1, 2, 4, 8, 16, 32], 32), - # All sizes capped: num_blocks < smallest size - ([8, 16, 32], 32, 4, [], 0), - # num_blocks <= 0: early return, no change - ([1, 2, 4], 4, 0, [1, 2, 4], 4), - ], -) -def test_adjust_cudagraph_sizes_for_mamba_cache( - capture_sizes, max_size, num_blocks, expected_sizes, expected_max -): - """Test that cudagraph capture sizes are correctly capped to fit - available Mamba cache blocks. - - See: https://github.com/vllm-project/vllm/issues/34094 - """ - config = CompilationConfig( - cudagraph_capture_sizes=capture_sizes, - max_cudagraph_capture_size=max_size, - cudagraph_mode=CUDAGraphMode.NONE, - ) - config.adjust_cudagraph_sizes_for_mamba_cache(num_blocks) - assert config.cudagraph_capture_sizes == expected_sizes - assert config.max_cudagraph_capture_size == expected_max - # Invariant: last element == max_cudagraph_capture_size - if expected_sizes: - assert config.cudagraph_capture_sizes[-1] == config.max_cudagraph_capture_size - - def test_inductor_asserts_default_disabled(monkeypatch): """Test that inductor runtime asserts are disabled by default (INFO logging level) on torch < 2.12.""" diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 93c5435e817b..02fac6dfc142 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1191,9 +1191,9 @@ def test_is_uniform_decode() -> None: current_platform.is_rocm(), reason="Attention backend FLASHINFER is not supported on ROCm.", ) -def test_cudagraph_sizes_capped_for_mamba_cache(): - """Test that cudagraph capture sizes are capped to num_blocks for - hybrid models with Mamba layers. +def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks(): + """Test that a ValueError is raised when max_num_seqs exceeds the + available Mamba cache blocks for hybrid models with FULL cudagraphs. See: https://github.com/vllm-project/vllm/issues/34094 """ @@ -1284,23 +1284,8 @@ def test_cudagraph_sizes_capped_for_mamba_cache(): )[0] num_blocks = kv_cache_config.num_blocks - # Set max_cudagraph_capture_size to a value larger than num_blocks - # to trigger the Mamba capping logic. - large_max = num_blocks + 100 - compilation_config = vllm_config.compilation_config - compilation_config.max_cudagraph_capture_size = large_max - compilation_config.cudagraph_capture_sizes = [ - s for s in [1, 2, 4, 8, 16, 32, 64, 128, 256, 512] if s <= large_max - ] + # Force max_num_seqs to exceed num_blocks so the check triggers. + runner.max_num_reqs = num_blocks + 100 - runner.initialize_kv_cache(kv_cache_config) - - # After initialization, cudagraph sizes should be capped - assert compilation_config.max_cudagraph_capture_size <= num_blocks - assert all(s <= num_blocks for s in compilation_config.cudagraph_capture_sizes) - # Invariant: last element == max - if compilation_config.cudagraph_capture_sizes: - assert ( - compilation_config.cudagraph_capture_sizes[-1] - == compilation_config.max_cudagraph_capture_size - ) + with pytest.raises(ValueError, match="max_num_seqs"): + runner.initialize_kv_cache(kv_cache_config) diff --git a/vllm/config/compilation.py b/vllm/config/compilation.py index 5b6648908dd6..1d09e2b7de70 100644 --- a/vllm/config/compilation.py +++ b/vllm/config/compilation.py @@ -1279,58 +1279,6 @@ def adjust_cudagraph_sizes_for_spec_decode( self.max_cudagraph_capture_size = rounded_sizes[-1] self.cudagraph_capture_sizes = rounded_sizes - def adjust_cudagraph_sizes_for_mamba_cache( - self, num_mamba_cache_blocks: int - ) -> None: - """Cap cudagraph capture sizes to available Mamba cache blocks. - - For hybrid Mamba/attention models, the Mamba conv_state and - ssm_state tensors have their first dimension equal to num_blocks - (from KVCacheConfig). During CUDA graph capture the decode batch - size equals num_tokens, so capture sizes exceeding num_blocks - would cause out-of-bounds access in Mamba kernels. - - See: https://github.com/vllm-project/vllm/issues/34094 - """ - if not self.cudagraph_capture_sizes or num_mamba_cache_blocks <= 0: - return - - assert self.max_cudagraph_capture_size is not None - - if num_mamba_cache_blocks >= self.max_cudagraph_capture_size: - return - - capped_sizes = [ - s for s in self.cudagraph_capture_sizes if s <= num_mamba_cache_blocks - ] - - if len(capped_sizes) == 0: - logger.warning( - "No valid cudagraph capture sizes remain after capping " - "to Mamba cache blocks (%d). The smallest capture size " - "was %d. Disabling cudagraph capture. Consider reducing " - "max_num_seqs or increasing available GPU memory.", - num_mamba_cache_blocks, - self.cudagraph_capture_sizes[0], - ) - self.cudagraph_capture_sizes = [] - self.max_cudagraph_capture_size = 0 - return - - logger.warning( - "Capping cudagraph capture sizes from max %d to %d to fit " - "Mamba cache blocks (%d blocks available). This limits the " - "maximum batch size that can use CUDA graphs. To increase " - "this limit, reduce max_num_seqs or increase available GPU " - "memory.", - self.max_cudagraph_capture_size, - capped_sizes[-1], - num_mamba_cache_blocks, - ) - - self.max_cudagraph_capture_size = capped_sizes[-1] - self.cudagraph_capture_sizes = capped_sizes - def get_compile_ranges(self) -> list[Range]: """Get the compile ranges for the compilation config.""" if self.compile_ranges_endpoints is None: diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8a43f43d0398..8cfa61baa599 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -5800,7 +5800,7 @@ def _init_minimal_kv_cache_for_profiling(self) -> None: ) self.cache_config.num_gpu_blocks_override = saved_override - self.initialize_kv_cache(minimal_config) + self.initialize_kv_cache(minimal_config, is_profiling=True) self.cache_config.num_gpu_blocks = minimal_config.num_blocks logger.debug("Initialized minimal KV cache for CUDA graph profiling") @@ -6121,7 +6121,11 @@ def _capture_cudagraphs( torch.accelerator.synchronize() self.maybe_remove_all_loras(self.lora_config) - def initialize_attn_backend(self, kv_cache_config: KVCacheConfig) -> None: + def initialize_attn_backend( + self, + kv_cache_config: KVCacheConfig, + is_profiling: bool = False, + ) -> None: """ Initialize the attention backends and attention metadata builders. """ @@ -6193,7 +6197,9 @@ def create_attn_groups( # Resolve cudagraph_mode before actually initialize metadata_builders self._check_and_update_cudagraph_mode( - attention_backend_list, kv_cache_config.kv_cache_groups + attention_backend_list, + kv_cache_config.kv_cache_groups, + is_profiling=is_profiling, ) # Check if attention backend supports PCP&DCP and related features. @@ -6237,6 +6243,7 @@ def _check_and_update_cudagraph_mode( self, attention_backends: list[set[type[AttentionBackend]]], kv_cache_groups: list[KVCacheGroupSpec], + is_profiling: bool = False, ) -> None: """ Resolve the cudagraph_mode when there are multiple attention @@ -6377,21 +6384,29 @@ def _check_and_update_cudagraph_mode( self.uniform_decode_query_len, self.parallel_config.tensor_parallel_size ) - # If the model has Mamba layers and cudagraph mode includes FULL - # decode, cap cudagraph capture sizes to the number of available - # Mamba cache blocks. Each decode request needs one conv_state - # cache line, so capture batch sizes cannot exceed num_blocks. - # Only FULL decode graphs are affected because PIECEWISE captures - # run GDN/Mamba ops eagerly (prefill path, no causal_conv1d_update). + # For Mamba models with FULL decode cudagraphs, each decode + # sequence needs one Mamba cache block. The decode cudagraph + # dispatcher already caps batch sizes at max_num_seqs, so we just + # need to verify that enough blocks exist. Raising here instead + # of silently capping cudagraph_capture_sizes avoids unintended + # restrictions on PIECEWISE (prefill) cudagraphs. # See: https://github.com/vllm-project/vllm/issues/34094 - if cudagraph_mode.has_full_cudagraphs(): + if cudagraph_mode.has_full_cudagraphs() and not is_profiling: has_mamba = any( isinstance(g.kv_cache_spec, MambaSpec) for g in kv_cache_groups ) if has_mamba and self.kv_cache_config is not None: - self.compilation_config.adjust_cudagraph_sizes_for_mamba_cache( - self.kv_cache_config.num_blocks - ) + num_blocks = self.kv_cache_config.num_blocks + if self.max_num_reqs > num_blocks: + raise ValueError( + f"max_num_seqs ({self.max_num_reqs}) exceeds " + f"available Mamba cache blocks ({num_blocks}). " + f"Each decode sequence requires one Mamba cache " + f"block, so CUDA graph capture cannot proceed. " + f"Please lower max_num_seqs to at most " + f"{num_blocks} or increase " + f"gpu_memory_utilization." + ) # Trigger cudagraph dispatching keys initialization after # resolved cudagraph mode. @@ -6752,7 +6767,11 @@ def maybe_add_kv_sharing_layers_to_kv_cache_groups( else: break - def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: + def initialize_kv_cache( + self, + kv_cache_config: KVCacheConfig, + is_profiling: bool = False, + ) -> None: """ Initialize KV cache based on `kv_cache_config`. Args: @@ -6764,7 +6783,7 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None: self._mamba_copy_bufs = None self.may_add_encoder_only_layers_to_kv_cache_config() self.maybe_add_kv_sharing_layers_to_kv_cache_groups(kv_cache_config) - self.initialize_attn_backend(kv_cache_config) + self.initialize_attn_backend(kv_cache_config, is_profiling=is_profiling) # The kernel block size for all KV cache groups. For example, if # kv_cache_manager uses block_size 256 for a given group, but the attention # backends for that group only supports block_size 64, we will return From 3683fe6c0651fe54a0201552ae7dfb7acb1e0cea Mon Sep 17 00:00:00 2001 From: yzong-rh Date: Mon, 30 Mar 2026 06:12:13 -0400 Subject: [PATCH 0015/1601] [Bugfix] Fix shared-object aliasing in n>1 streaming with tool calls (#38158) Signed-off-by: Yifan Zong Signed-off-by: Yifan Co-authored-by: Chauncey --- .../chat_completion/test_serving_chat.py | 165 ++++++++++++++++++ .../openai/chat_completion/serving.py | 5 +- 2 files changed, 168 insertions(+), 2 deletions(-) diff --git a/tests/entrypoints/openai/chat_completion/test_serving_chat.py b/tests/entrypoints/openai/chat_completion/test_serving_chat.py index 89bb8a1fe9c6..cb356e0e1986 100644 --- a/tests/entrypoints/openai/chat_completion/test_serving_chat.py +++ b/tests/entrypoints/openai/chat_completion/test_serving_chat.py @@ -1,6 +1,7 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio +import json from contextlib import suppress from dataclasses import dataclass, field from typing import Any @@ -1792,6 +1793,170 @@ async def test_tool_choice_validation_without_parser(): assert "--tool-call-parser" in response_named.error.message +@pytest.mark.asyncio +async def test_streaming_n_gt1_independent_tool_parsers(): + """n>1 streaming must use independent parser instances + and token-id histories per choice. + """ + mock_engine = MagicMock(spec=AsyncLLM) + mock_engine.errored = False + mock_engine.model_config = MockModelConfig() + mock_engine.input_processor = MagicMock() + mock_engine.io_processor = MagicMock() + mock_engine.renderer = _build_renderer(mock_engine.model_config) + + models = OpenAIServingModels( + engine_client=mock_engine, + base_model_paths=BASE_MODEL_PATHS, + ) + openai_serving_render = _build_serving_render(mock_engine, models.registry) + + serving_chat = OpenAIServingChat( + mock_engine, + models, + response_role="assistant", + openai_serving_render=openai_serving_render, + chat_template=CHAT_TEMPLATE, + chat_template_content_format="auto", + request_logger=None, + enable_auto_tools=True, + tool_parser="hermes", + ) + + tokenizer = get_tokenizer(MODEL_NAME) + + tools = [ + { + "type": "function", + "function": { + "name": "get_weather", + "description": "Get weather", + "parameters": { + "type": "object", + "properties": {"city": {"type": "string"}}, + "required": ["city"], + }, + }, + } + ] + + num_choices = 2 + + request = ChatCompletionRequest( + model=MODEL_NAME, + messages=[{"role": "user", "content": "test"}], + n=num_choices, + stream=True, + tools=tools, + tool_choice="auto", + ) + + tool_call_text = ( + "\n" + '{"name": "get_weather", "arguments": {"city": "Tokyo"}}\n' + "" + ) + all_token_ids = tokenizer.encode(tool_call_text, add_special_tokens=False) + + # Compute proper delta text for each token so that concatenated deltas + # reproduce the original string exactly. + steps: list[tuple[str, int]] = [] + prev_decoded = "" + for i, tid in enumerate(all_token_ids): + decoded_so_far = tokenizer.decode(all_token_ids[: i + 1]) + delta = decoded_so_far[len(prev_decoded) :] + steps.append((delta, tid)) + prev_decoded = decoded_so_far + + async def result_generator(): + for delta_text, token_id in steps: + yield RequestOutput( + request_id="test-req", + prompt="test", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[ + CompletionOutput( + index=choice_idx, + text=delta_text, + token_ids=[token_id], + cumulative_logprob=0.0, + logprobs=None, + ) + for choice_idx in range(num_choices) + ], + finished=False, + ) + # Final output with finish_reason + yield RequestOutput( + request_id="test-req", + prompt="test", + prompt_token_ids=[1, 2, 3], + prompt_logprobs=None, + outputs=[ + CompletionOutput( + index=choice_idx, + text="", + token_ids=[], + cumulative_logprob=0.0, + logprobs=None, + finish_reason="stop", + ) + for choice_idx in range(num_choices) + ], + finished=True, + ) + + # Collect tool-call deltas per choice from the SSE stream. + tc_deltas_by_choice: dict[int, list[dict]] = {i: [] for i in range(num_choices)} + async for chunk_str in serving_chat.chat_completion_stream_generator( + request=request, + result_generator=result_generator(), + request_id="test-req", + model_name=MODEL_NAME, + conversation=[], + tokenizer=tokenizer, + request_metadata=RequestResponseMetadata( + request_id="test-req", + model_name=MODEL_NAME, + ), + ): + if not chunk_str.strip() or "data: [DONE]" in chunk_str: + continue + if chunk_str.startswith("data: "): + data = json.loads(chunk_str[6:].strip()) + for choice in data.get("choices", []): + idx = choice["index"] + delta = choice.get("delta", {}) + if delta.get("tool_calls"): + for tc in delta["tool_calls"]: + tc_deltas_by_choice[idx].append(tc) + + # Both choices must independently produce the correct tool call. + for choice_idx in range(num_choices): + deltas = tc_deltas_by_choice[choice_idx] + assert len(deltas) > 0, ( + f"Choice {choice_idx}: expected tool-call deltas but got none" + ) + + name = None + args_buf = "" + for tc in deltas: + fn = tc.get("function", {}) + if fn.get("name"): + name = fn["name"] + if fn.get("arguments"): + args_buf += fn["arguments"] + + assert name == "get_weather", ( + f"Choice {choice_idx}: expected 'get_weather', got {name!r}" + ) + parsed_args = json.loads(args_buf) + assert parsed_args == {"city": "Tokyo"}, ( + f"Choice {choice_idx}: expected {{'city': 'Tokyo'}}, got {parsed_args}" + ) + + class TestCreateRemainingArgsDelta: """Tests for _create_remaining_args_delta helper function. diff --git a/vllm/entrypoints/openai/chat_completion/serving.py b/vllm/entrypoints/openai/chat_completion/serving.py index 493c26d3aed9..a426836afd35 100644 --- a/vllm/entrypoints/openai/chat_completion/serving.py +++ b/vllm/entrypoints/openai/chat_completion/serving.py @@ -548,7 +548,7 @@ async def chat_completion_stream_generator( # all_previous_token_ids will not be used twice in the same iteration. if tool_choice_auto or reasoning_parser: # These are only required in "auto" tool choice case - all_previous_token_ids = [[]] * num_choices + all_previous_token_ids = [[] for _ in range(num_choices)] # For reasoning parser and tool call all enabled added_content_delta_arr = [False] * num_choices reasoning_end_arr = [False] * num_choices @@ -566,7 +566,8 @@ async def chat_completion_stream_generator( tool_parsers: list[ToolParser | None] = [ self.tool_parser(tokenizer, request.tools) - ] * num_choices + for _ in range(num_choices) + ] else: tool_parsers = [None] * num_choices except Exception as e: From 7e76af14fabfc72144c627331b87d581df361158 Mon Sep 17 00:00:00 2001 From: aliialsaeedii Date: Mon, 30 Mar 2026 12:26:46 +0200 Subject: [PATCH 0016/1601] [Bugfix][Frontend] Return 400 for corrupt/truncated image inputs instead of 500 (#38253) Signed-off-by: aliialsaeedii --- tests/multimodal/media/test_image.py | 74 ++++++++++++++++++++++++++++ vllm/multimodal/media/image.py | 14 +++--- 2 files changed, 82 insertions(+), 6 deletions(-) diff --git a/tests/multimodal/media/test_image.py b/tests/multimodal/media/test_image.py index 065a40d68e35..65196d7805cc 100644 --- a/tests/multimodal/media/test_image.py +++ b/tests/multimodal/media/test_image.py @@ -131,3 +131,77 @@ def test_image_media_io_rgba_background_color_validation(): ImageMediaIO(rgba_background_color=(0, 0, 0)) # Should not raise ImageMediaIO(rgba_background_color=[255, 255, 255]) # Should not raise ImageMediaIO(rgba_background_color=(128, 128, 128)) # Should not raise + + +def test_image_media_io_load_bytes(tmp_path): + """Test load_bytes with valid and invalid image data.""" + # Save a valid RGB image to use as source bytes + valid_image = Image.new("RGB", (8, 8), (100, 150, 200)) + valid_path = tmp_path / "valid.png" + valid_image.save(valid_path) + + valid_data = valid_path.read_bytes() + + # Test 1: Valid image bytes load successfully and are fully decoded + image_io = ImageMediaIO() + result = image_io.load_bytes(valid_data) + + # Check the returned media is a properly loaded image + assert isinstance(result.media, Image.Image) + assert result.media.size == (8, 8) + assert result.media.getpixel((0, 0)) == (100, 150, 200) + + # Test 2: Garbage bytes raise ValueError + with pytest.raises(ValueError, match="Failed to load image"): + image_io.load_bytes(b"not an image") + + # Test 3: Truncated PNG header raises ValueError + with pytest.raises(ValueError, match="Failed to load image"): + image_io.load_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 10) + + # Test 4: Real PNG truncated mid-stream raises ValueError + with pytest.raises(ValueError, match="Failed to load image"): + image_io.load_bytes(valid_data[: len(valid_data) // 2]) + + # Test 5: Empty bytes raise ValueError + with pytest.raises(ValueError, match="Failed to load image"): + image_io.load_bytes(b"") + + +def test_image_media_io_load_file(tmp_path): + """Test load_file with valid and invalid image files.""" + # Save a valid RGB image to disk + valid_image = Image.new("RGB", (4, 4), (10, 20, 30)) + valid_path = tmp_path / "valid.png" + valid_image.save(valid_path) + + # Test 1: Valid image file loads successfully and is fully decoded + image_io = ImageMediaIO() + result = image_io.load_file(valid_path) + + # Check the returned media is a properly loaded image + assert isinstance(result.media, Image.Image) + assert result.media.size == (4, 4) + assert result.media.getpixel((0, 0)) == (10, 20, 30) + + # Test 2: File with garbage content raises ValueError + bad_file = tmp_path / "bad.png" + bad_file.write_bytes(b"this is not an image") + + with pytest.raises(ValueError, match="Failed to load image"): + image_io.load_file(bad_file) + + # Test 3: File with truncated PNG header raises ValueError + truncated_file = tmp_path / "truncated.png" + truncated_file.write_bytes(b"\x89PNG\r\n\x1a\n" + b"\x00" * 10) + + with pytest.raises(ValueError, match="Failed to load image"): + image_io.load_file(truncated_file) + + # Test 4: Real PNG file truncated mid-stream raises ValueError + valid_data = valid_path.read_bytes() + truncated_real_file = tmp_path / "truncated_real.png" + truncated_real_file.write_bytes(valid_data[: len(valid_data) // 2]) + + with pytest.raises(ValueError, match="Failed to load image"): + image_io.load_file(truncated_real_file) diff --git a/vllm/multimodal/media/image.py b/vllm/multimodal/media/image.py index ea4bf7b01527..ea816b760fea 100644 --- a/vllm/multimodal/media/image.py +++ b/vllm/multimodal/media/image.py @@ -68,17 +68,19 @@ def _convert_image_mode( return convert_image_mode(image, self.image_mode) def load_bytes(self, data: bytes) -> MediaWithBytes[Image.Image]: - image = Image.open(BytesIO(data)) - return MediaWithBytes(self._convert_image_mode(image), data) + try: + image = Image.open(BytesIO(data)) + image.load() + image = self._convert_image_mode(image) + except (OSError, Image.UnidentifiedImageError) as e: + raise ValueError(f"Failed to load image: {e}") from e + return MediaWithBytes(image, data) def load_base64(self, media_type: str, data: str) -> MediaWithBytes[Image.Image]: return self.load_bytes(pybase64.b64decode(data, validate=True)) def load_file(self, filepath: Path) -> MediaWithBytes[Image.Image]: - with open(filepath, "rb") as f: - data = f.read() - image = Image.open(BytesIO(data)) - return MediaWithBytes(self._convert_image_mode(image), data) + return self.load_bytes(filepath.read_bytes()) def encode_base64( self, From 1031c84c360874ddc37589f21b15686788ac142e Mon Sep 17 00:00:00 2001 From: Collin McCarthy Date: Mon, 30 Mar 2026 04:09:45 -0700 Subject: [PATCH 0017/1601] Fix ambiguous num_blocks for hybrid attn mamba (#37236) Signed-off-by: Collin McCarthy Signed-off-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> Co-authored-by: Netanel Haber <58652339+netanel-haber@users.noreply.github.com> --- tests/v1/worker/test_gpu_model_runner.py | 32 ++++++++++++++++++++++- vllm/v1/worker/gpu_model_runner.py | 33 ++++++++++++++---------- 2 files changed, 51 insertions(+), 14 deletions(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 02fac6dfc142..0fbd6605a7fe 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -1,6 +1,8 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from types import SimpleNamespace + import numpy as np import pytest import torch @@ -30,6 +32,7 @@ from vllm.v1.core.kv_cache_utils import estimate_max_model_len, get_kv_cache_configs from vllm.v1.core.sched.output import CachedRequestData, NewRequestData, SchedulerOutput from vllm.v1.kv_cache_interface import ( + AttentionSpec, FullAttentionSpec, KVCacheConfig, KVCacheGroupSpec, @@ -38,7 +41,7 @@ from vllm.v1.sample.metadata import SamplingMetadata from vllm.v1.worker.gpu_input_batch import InputBatch from vllm.v1.worker.gpu_model_runner import GPUModelRunner -from vllm.v1.worker.utils import select_common_block_size +from vllm.v1.worker.utils import AttentionGroup, select_common_block_size BLOCK_SIZE = 16 NUM_BLOCKS = 10 @@ -946,6 +949,33 @@ def test_hybrid_attention_mamba_tensor_shapes(): assert torch.equal(actual_ssm, expected_ssm) +def test_update_hybrid_attention_mamba_layout_with_num_block_2_rewrites_stride(): + from vllm.v1.attention.backends.flash_attn import FlashAttentionBackend + + ambiguous_cache = torch.empty((2, 2, BLOCK_SIZE, 1, 8), dtype=torch.float16) + """Ambiguous, because both dims[0=kv_dim] and dims[1=num_blocks] == 2""" + hidden_size = ambiguous_cache.shape[2:].numel() + assert ambiguous_cache.stride()[:2] == (2 * hidden_size, hidden_size) + + attention_spec = AttentionSpec( + block_size=BLOCK_SIZE, num_kv_heads=1, head_size=8, dtype=torch.float16 + ) + runner_stub = SimpleNamespace( + cache_config=SimpleNamespace(cache_dtype="auto"), + _kv_cache_spec_attn_group_iterator=lambda: iter( + [AttentionGroup(FlashAttentionBackend, ["attn"], attention_spec, 0)] + ), + ) + GPUModelRunner._update_hybrid_attention_mamba_layout( + runner_stub, {"attn": ambiguous_cache}, [BLOCK_SIZE] + ) + + assert ambiguous_cache.stride()[:2] == (hidden_size, 2 * hidden_size), """\ + We expect _update_hybrid_attention_mamba_layout to re-stride the cache from: + (2, num_blocks) -> (num_blocks, 2), even when num_blocks==2, + which was ambiguous before get_kv_cache_block_dim was used""" + + def test_hybrid_block_table_initialization(): """Test hybrid block table with different kernel and kvcache_manager block sizes.""" diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 8cfa61baa599..1f946cda0eea 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -6653,12 +6653,12 @@ def _reshape_kv_cache_tensors( raise NotImplementedError if has_attn and has_mamba: - self._update_hybrid_attention_mamba_layout(kv_caches) + self._update_hybrid_attention_mamba_layout(kv_caches, kernel_block_sizes) return kv_caches def _update_hybrid_attention_mamba_layout( - self, kv_caches: dict[str, torch.Tensor] + self, kv_caches: dict[str, torch.Tensor], kernel_block_sizes: list[int] ) -> None: """ Update the layout of attention layers from (2, num_blocks, ...) to @@ -6666,23 +6666,30 @@ def _update_hybrid_attention_mamba_layout( Args: kv_caches: The KV cache buffer of each layer. + kernel_block_sizes: The kernel block sizes for each KV cache group. """ for group in self._kv_cache_spec_attn_group_iterator(): kv_cache_spec = group.kv_cache_spec + if not isinstance(kv_cache_spec, AttentionSpec): + continue + block_dim = group.backend.get_kv_cache_block_dim( + kernel_block_sizes[group.kv_cache_group_id], + kv_cache_spec.num_kv_heads, + kv_cache_spec.head_size, + cache_dtype_str=self.cache_config.cache_dtype, + ) + # block_dim: 0 means (num_blocks, 2, ...); 1 means (2, num_blocks, ...). + if block_dim == 0: + continue + assert block_dim == 1 for layer_name in group.layer_names: kv_cache = kv_caches[layer_name] - if isinstance(kv_cache_spec, AttentionSpec) and kv_cache.shape[0] == 2: - assert kv_cache.shape[1] != 2, ( - "Fail to determine whether the layout is " - "(2, num_blocks, ...) or (num_blocks, 2, ...) for " - f"a tensor of shape {kv_cache.shape}" - ) - hidden_size = kv_cache.shape[2:].numel() - kv_cache.as_strided_( - size=kv_cache.shape, - stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]), - ) + hidden_size = kv_cache.shape[2:].numel() + kv_cache.as_strided_( + size=kv_cache.shape, + stride=(hidden_size, 2 * hidden_size, *kv_cache.stride()[2:]), + ) def initialize_kv_cache_tensors( self, kv_cache_config: KVCacheConfig, kernel_block_sizes: list[int] From 677424c7acd9fb7477294017c99f798588002d4f Mon Sep 17 00:00:00 2001 From: Andreas Karatzas Date: Mon, 30 Mar 2026 06:58:53 -0500 Subject: [PATCH 0018/1601] [Core][CI] Add opt-in media URL caching via VLLM_MEDIA_CACHE (#37123) Signed-off-by: Andreas Karatzas --- tests/multimodal/media/test_connector.py | 112 +++++++++++++++++++ vllm/envs.py | 19 ++++ vllm/multimodal/media/connector.py | 135 +++++++++++++++++++++++ 3 files changed, 266 insertions(+) diff --git a/tests/multimodal/media/test_connector.py b/tests/multimodal/media/test_connector.py index c771cc9a3fdf..b78d24d189f8 100644 --- a/tests/multimodal/media/test_connector.py +++ b/tests/multimodal/media/test_connector.py @@ -4,6 +4,8 @@ import asyncio import mimetypes import os +import shutil +import time from tempfile import NamedTemporaryFile, TemporaryDirectory import aiohttp @@ -375,3 +377,113 @@ async def test_ssrf_bypass_backslash_disallowed_domain(): with pytest.raises(ValueError, match="allowed domains"): await connector.fetch_image_async(bypass_url) + + +def _make_cached_connector(cache_dir, *, max_mb=10, ttl_hours=24): + """Create a MediaConnector with caching enabled via monkeypatched internals. + + We bypass __init__'s env-var path and wire up the cache fields directly + so tests don't depend on environment variables. URLs in these tests are + only used as cache keys (hashed to derive filenames); no HTTP requests + are made. + """ + connector = MediaConnector() + connector._media_cache_dir = cache_dir + connector._media_cache_max_bytes = max_mb * 1024 * 1024 + connector._media_cache_ttl_secs = ttl_hours * 3600 + return connector + + +def test_cache_put_and_get(): + """Basic round-trip: put bytes, get them back.""" + with TemporaryDirectory() as cache_dir: + connector = _make_cached_connector(cache_dir) + url = "https://example.com/image.png" + data = b"fake-image-bytes" + + connector._put_cached_bytes(url, data) + cached = connector._get_cached_bytes(url) + assert cached == data + + +def test_cache_ttl_expiry(): + """Entries older than TTL are evicted on read.""" + with TemporaryDirectory() as cache_dir: + connector = _make_cached_connector(cache_dir, ttl_hours=24) + url = "https://example.com/old.png" + data = b"old-data" + + connector._put_cached_bytes(url, data) + + # Backdate the file's mtime so it appears expired + cache_path = connector._media_cache_path(url) + expired_time = time.time() - (25 * 3600) # 25 hours ago + os.utime(cache_path, (expired_time, expired_time)) + + assert connector._get_cached_bytes(url) is None + assert not cache_path.exists() + + +def test_cache_lru_eviction(): + """Oldest entries are evicted when cache exceeds size budget.""" + with TemporaryDirectory() as cache_dir: + # Set a very small max size: 100 bytes + connector = _make_cached_connector(cache_dir, max_mb=0) + connector._media_cache_max_bytes = 100 + + # Write three 50-byte entries (total 150 > 100 budget) + urls = [f"https://example.com/{i}.png" for i in range(3)] + for i, url in enumerate(urls): + connector._put_cached_bytes(url, b"x" * 50) + # Stagger mtime so eviction order is deterministic + path = connector._media_cache_path(url) + os.utime(path, (time.time() + i, time.time() + i)) + + # The oldest entry (urls[0]) should have been evicted + assert connector._get_cached_bytes(urls[0]) is None + # The newest entries should still be present + assert connector._get_cached_bytes(urls[2]) == b"x" * 50 + + +def test_cache_ttl_eviction_during_write(): + """_maybe_evict removes expired files even if under size budget.""" + with TemporaryDirectory() as cache_dir: + connector = _make_cached_connector(cache_dir, ttl_hours=1) + url_old = "https://example.com/stale.png" + url_new = "https://example.com/fresh.png" + + connector._put_cached_bytes(url_old, b"stale") + # Backdate old entry past TTL + old_path = connector._media_cache_path(url_old) + expired_time = time.time() - (2 * 3600) + os.utime(old_path, (expired_time, expired_time)) + + # Writing a new entry triggers _maybe_evict + connector._put_cached_bytes(url_new, b"fresh") + + assert not old_path.exists() + assert connector._get_cached_bytes(url_new) == b"fresh" + + +def test_put_cached_bytes_missing_dir(): + """_put_cached_bytes does not crash when the cache dir disappears.""" + with TemporaryDirectory() as cache_dir: + connector = _make_cached_connector(cache_dir) + # Remove the directory to simulate it disappearing at runtime + shutil.rmtree(cache_dir) + + # Should not raise (graceful degradation) + connector._put_cached_bytes("https://example.com/x.png", b"data") + + +def test_get_cached_bytes_file_deleted_before_read(): + """_get_cached_bytes returns None if the file vanishes mid-read.""" + with TemporaryDirectory() as cache_dir: + connector = _make_cached_connector(cache_dir) + url = "https://example.com/vanish.png" + + connector._put_cached_bytes(url, b"data") + # Delete the file to simulate concurrent eviction + connector._media_cache_path(url).unlink() + + assert connector._get_cached_bytes(url) is None diff --git a/vllm/envs.py b/vllm/envs.py index 2944bb111d24..0bd8f0fec1f6 100755 --- a/vllm/envs.py +++ b/vllm/envs.py @@ -64,6 +64,9 @@ VLLM_IMAGE_FETCH_TIMEOUT: int = 5 VLLM_VIDEO_FETCH_TIMEOUT: int = 30 VLLM_AUDIO_FETCH_TIMEOUT: int = 10 + VLLM_MEDIA_CACHE: str = "" + VLLM_MEDIA_CACHE_MAX_SIZE_MB: int = 5120 + VLLM_MEDIA_CACHE_TTL_HOURS: float = 24 VLLM_MEDIA_FETCH_MAX_RETRIES: int = 3 VLLM_MEDIA_URL_ALLOW_REDIRECTS: bool = True VLLM_MEDIA_LOADING_THREAD_COUNT: int = 8 @@ -776,6 +779,19 @@ def _get_or_set_default() -> str: "VLLM_AUDIO_FETCH_TIMEOUT": lambda: int( os.getenv("VLLM_AUDIO_FETCH_TIMEOUT", "10") ), + # Directory for caching media downloads (images, video, audio fetched + # from URLs during inference). Empty string disables caching. + "VLLM_MEDIA_CACHE": lambda: os.getenv("VLLM_MEDIA_CACHE", ""), + # Maximum cache size in MB. When exceeded, least-recently-used entries + # are evicted. Default is 5120 (5 GB). + "VLLM_MEDIA_CACHE_MAX_SIZE_MB": lambda: int( + os.getenv("VLLM_MEDIA_CACHE_MAX_SIZE_MB", "5120") + ), + # Time-to-live in hours for cached media files. Entries older than this + # are evicted regardless of cache size. Default is 24 hours. + "VLLM_MEDIA_CACHE_TTL_HOURS": lambda: float( + os.getenv("VLLM_MEDIA_CACHE_TTL_HOURS", "24") + ), # Maximum number of retries for fetching media (images, audio, video) # from URLs. Each retry quadruples the timeout. Default is 3. "VLLM_MEDIA_FETCH_MAX_RETRIES": lambda: int( @@ -1777,6 +1793,9 @@ def compile_factors() -> dict[str, object]: "VLLM_IMAGE_FETCH_TIMEOUT", "VLLM_VIDEO_FETCH_TIMEOUT", "VLLM_AUDIO_FETCH_TIMEOUT", + "VLLM_MEDIA_CACHE", + "VLLM_MEDIA_CACHE_MAX_SIZE_MB", + "VLLM_MEDIA_CACHE_TTL_HOURS", "VLLM_MEDIA_FETCH_MAX_RETRIES", "VLLM_MEDIA_URL_ALLOW_REDIRECTS", "VLLM_MEDIA_LOADING_THREAD_COUNT", diff --git a/vllm/multimodal/media/connector.py b/vllm/multimodal/media/connector.py index 80aaa2a8293e..6ec90a45f1d0 100644 --- a/vllm/multimodal/media/connector.py +++ b/vllm/multimodal/media/connector.py @@ -3,6 +3,11 @@ import asyncio import atexit +import contextlib +import hashlib +import os +import tempfile +import time from concurrent.futures import ThreadPoolExecutor from pathlib import Path from typing import Any, TypeVar @@ -16,6 +21,7 @@ import vllm.envs as envs from vllm.connections import HTTPConnection, global_http_connection +from vllm.logger import init_logger from vllm.utils.registry import ExtensionManager from .audio import AudioEmbeddingMediaIO, AudioMediaIO @@ -23,6 +29,8 @@ from .image import ImageEmbeddingMediaIO, ImageMediaIO from .video import VideoMediaIO +logger = init_logger(__name__) + _M = TypeVar("_M") global_thread_pool = ThreadPoolExecutor( @@ -116,6 +124,115 @@ def __init__( allowed_media_domains = [] self.allowed_media_domains = allowed_media_domains + # Media download cache (opt-in via VLLM_MEDIA_CACHE) + self._media_cache_dir: str | None = None + self._media_cache_max_bytes: int = 0 + self._media_cache_ttl_secs: float = 0 + media_cache = envs.VLLM_MEDIA_CACHE + if media_cache: + try: + os.makedirs(media_cache, exist_ok=True) + # Verify the directory is writable before enabling caching + with tempfile.NamedTemporaryFile(dir=media_cache, delete=True): + pass + self._media_cache_dir = media_cache + self._media_cache_max_bytes = ( + envs.VLLM_MEDIA_CACHE_MAX_SIZE_MB * 1024 * 1024 + ) + self._media_cache_ttl_secs = envs.VLLM_MEDIA_CACHE_TTL_HOURS * 3600 + logger.info( + "Media cache enabled at %s (max %d MB, TTL %s hours)", + media_cache, + envs.VLLM_MEDIA_CACHE_MAX_SIZE_MB, + envs.VLLM_MEDIA_CACHE_TTL_HOURS, + ) + except OSError: + logger.warning( + "VLLM_MEDIA_CACHE path %s is not writable, media caching disabled", + media_cache, + ) + + def _get_cached_bytes(self, url: str) -> bytes | None: + """Return cached bytes for a URL, or None if not cached/expired.""" + if not self._media_cache_dir: + return None + cache_path = self._media_cache_path(url) + # Check TTL + try: + age = time.time() - cache_path.stat().st_mtime + except OSError: + return None + if age > self._media_cache_ttl_secs: + cache_path.unlink(missing_ok=True) + return None + # Touch mtime for LRU ordering + try: + cache_path.touch() + return cache_path.read_bytes() + except OSError: + return None + + def _put_cached_bytes(self, url: str, data: bytes) -> None: + """Store downloaded bytes and evict if over budget.""" + if not self._media_cache_dir: + return + cache_path = self._media_cache_path(url) + # Atomic write via temp file + rename + tmp_path = None + try: + with tempfile.NamedTemporaryFile( + mode="wb", dir=self._media_cache_dir, delete=False + ) as tmp_file: + tmp_file.write(data) + tmp_path = tmp_file.name + os.rename(tmp_path, str(cache_path)) + except OSError: + # Another process beat us or disk issue + if tmp_path is not None: + with contextlib.suppress(OSError): + os.remove(tmp_path) + return + self._maybe_evict(exclude=cache_path) + + def _maybe_evict(self, exclude: Path | None = None) -> None: + """Evict expired entries first, then LRU until under size limit.""" + cache_dir = Path(self._media_cache_dir) # type: ignore[arg-type] + entries = [] + expired = [] + total_size = 0 + now = time.time() + for f in cache_dir.iterdir(): + if f.name.startswith("."): + continue + try: + stat = f.stat() + except OSError: + continue + age = now - stat.st_mtime + if age > self._media_cache_ttl_secs: + expired.append(f) + continue + total_size += stat.st_size + # Never evict the file we just wrote + if exclude is not None and f.name == exclude.name: + continue + entries.append((stat.st_mtime, stat.st_size, f)) + + # Evict items according to LRU policy + entries.sort(key=lambda e: e[0], reverse=True) + while total_size > self._media_cache_max_bytes and entries: + mtime, size, f = entries.pop() + expired.append(f) + total_size -= size + + for f in expired: + f.unlink(missing_ok=True) + + def _media_cache_path(self, url: str) -> Path: + url_hash = hashlib.sha256(url.encode()).hexdigest()[:20] + ext = Path(url.split("?", 1)[0]).suffix or "" + return Path(self._media_cache_dir) / f"{url_hash}{ext}" # type: ignore[arg-type] + def _load_data_url( self, url_spec: Url, @@ -178,6 +295,10 @@ def load_from_url( if url_spec.scheme and url_spec.scheme.startswith("http"): self._assert_url_in_allowed_media_domains(url_spec) + cached = self._get_cached_bytes(url) + if cached is not None: + return media_io.load_bytes(cached) + connection = self.connection data = connection.get_bytes( url_spec.url, @@ -185,6 +306,7 @@ def load_from_url( allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS, ) + self._put_cached_bytes(url, data) return media_io.load_bytes(data) if url_spec.scheme == "data": @@ -209,12 +331,25 @@ async def load_from_url_async( if url_spec.scheme and url_spec.scheme.startswith("http"): self._assert_url_in_allowed_media_domains(url_spec) + cached = await loop.run_in_executor( + global_thread_pool, self._get_cached_bytes, url + ) + if cached is not None: + future = loop.run_in_executor( + global_thread_pool, media_io.load_bytes, cached + ) + return await future + connection = self.connection data = await connection.async_get_bytes( url_spec.url, timeout=fetch_timeout, allow_redirects=envs.VLLM_MEDIA_URL_ALLOW_REDIRECTS, ) + + await loop.run_in_executor( + global_thread_pool, self._put_cached_bytes, url, data + ) future = loop.run_in_executor(global_thread_pool, media_io.load_bytes, data) return await future From 6557f4937fd2937ae4824beb492ff67625895d89 Mon Sep 17 00:00:00 2001 From: "Li, Jiang" Date: Mon, 30 Mar 2026 20:13:00 +0800 Subject: [PATCH 0019/1601] [Bugfix][CPU] Skip set_num_threads after thread binding (#38535) Signed-off-by: jiang1.li --- .../hardware_ci/run-cpu-distributed-smoke-test.sh | 5 +++-- vllm/v1/worker/cpu_worker.py | 9 +++++++++ 2 files changed, 12 insertions(+), 2 deletions(-) diff --git a/.buildkite/scripts/hardware_ci/run-cpu-distributed-smoke-test.sh b/.buildkite/scripts/hardware_ci/run-cpu-distributed-smoke-test.sh index f289a43c6be4..d905403164a2 100644 --- a/.buildkite/scripts/hardware_ci/run-cpu-distributed-smoke-test.sh +++ b/.buildkite/scripts/hardware_ci/run-cpu-distributed-smoke-test.sh @@ -1,9 +1,10 @@ #!/bin/bash set -euox pipefail export VLLM_CPU_CI_ENV=0 +export VLLM_CPU_KVCACHE_SPACE=1 # avoid OOM echo "--- PP+TP" -vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 & +vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -pp=2 --max-model-len=4096 & server_pid=$! timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1 vllm bench serve \ @@ -23,7 +24,7 @@ if [ "$failed_req" -ne 0 ]; then fi echo "--- DP+TP" -vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 & +vllm serve meta-llama/Llama-3.2-3B-Instruct -tp=2 -dp=2 --max-model-len=4096 & server_pid=$! timeout 600 bash -c "until curl localhost:8000/v1/models > /dev/null 2>&1; do sleep 1; done" || exit 1 vllm bench serve \ diff --git a/vllm/v1/worker/cpu_worker.py b/vllm/v1/worker/cpu_worker.py index 122cacd14cd8..2547751c0d81 100644 --- a/vllm/v1/worker/cpu_worker.py +++ b/vllm/v1/worker/cpu_worker.py @@ -108,6 +108,15 @@ def check_preloaded_libs(name: str): if ret: logger.info(ret) + # After the thread binding, changing thread num is not allowed + def skip_set_num_threads(x: int): + logger.warning( + "CPU backend doesn't allow to use " + "`torch.set_num_threads` after the thread binding, skip it." + ) + + torch.set_num_threads = skip_set_num_threads + # Note: unique identifier for creating allreduce shared memory os.environ["VLLM_DIST_IDENT"] = self.distributed_init_method.split(":")[-1] # Initialize the distributed environment. From 7c3f88b2a895f58d653e9284dbb7b1da85c85e73 Mon Sep 17 00:00:00 2001 From: Thomas Parnell Date: Mon, 30 Mar 2026 14:32:26 +0200 Subject: [PATCH 0020/1601] [Bugfix] Remove false-positive format mismatch warnings in FLA ops (#38255) Signed-off-by: Thomas Parnell --- vllm/model_executor/layers/fla/ops/chunk.py | 8 -------- vllm/model_executor/layers/fla/ops/cumsum.py | 9 --------- 2 files changed, 17 deletions(-) diff --git a/vllm/model_executor/layers/fla/ops/chunk.py b/vllm/model_executor/layers/fla/ops/chunk.py index 9261885956e5..73cba7f9035c 100644 --- a/vllm/model_executor/layers/fla/ops/chunk.py +++ b/vllm/model_executor/layers/fla/ops/chunk.py @@ -7,7 +7,6 @@ # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 -import warnings import torch @@ -184,13 +183,6 @@ def chunk_gated_delta_rule( "ChunkGatedDeltaRuleFunction does not support float32. Please use bfloat16." ) assert len(beta.shape) == 3, "beta must be of shape [B, T, H]." - if q.shape[1] < q.shape[2]: - warnings.warn( - f"Input tensor shape suggests potential format mismatch: seq_len ({q.shape[1]}) < num_heads ({q.shape[2]}). " - "This may indicate the inputs were passed in head-first format [B, H, T, ...] " - "Please verify your input tensor format matches the expected shape [B, T, H, ...].", - stacklevel=2, - ) if cu_seqlens is not None: if q.shape[0] != 1: raise ValueError( diff --git a/vllm/model_executor/layers/fla/ops/cumsum.py b/vllm/model_executor/layers/fla/ops/cumsum.py index 99b41794796d..13238020cbd9 100644 --- a/vllm/model_executor/layers/fla/ops/cumsum.py +++ b/vllm/model_executor/layers/fla/ops/cumsum.py @@ -7,7 +7,6 @@ # the following copyright notice: # Copyright (c) 2023-2025, Songlin Yang, Yu Zhang # ruff: noqa: E501 -import warnings import torch @@ -252,14 +251,6 @@ def chunk_local_cumsum( output_dtype: torch.dtype | None = torch.float, **kwargs, ) -> torch.Tensor: - if not head_first and g.shape[1] < g.shape[2]: - warnings.warn( - f"Input tensor shape suggests potential format mismatch: seq_len ({g.shape[1]}) < num_heads ({g.shape[2]}). " - "This may indicate the inputs were passed in head-first format [B, H, T, ...] " - "when head_first=False was specified. " - "Please verify your input tensor format matches the expected shape [B, T, H, ...].", - stacklevel=2, - ) if cu_seqlens is not None: assert g.shape[0] == 1, ( "Only batch size 1 is supported when cu_seqlens are provided" From 246dc7d864c87f2b7eaf7ecbbd068f16e8b7a249 Mon Sep 17 00:00:00 2001 From: tomeras91 <57313761+tomeras91@users.noreply.github.com> Date: Mon, 30 Mar 2026 16:12:17 +0300 Subject: [PATCH 0021/1601] [Misc] Add @tomeras91 as a maintainer of Nemotron related code + mamba block (#38547) Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com> --- .github/CODEOWNERS | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 37707da8af26..e492811b6e1d 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -8,7 +8,7 @@ /vllm/model_executor/layers/attention @LucasWilkinson @MatthewBonanni /vllm/model_executor/layers/fused_moe @mgoin @pavanimajety /vllm/model_executor/layers/quantization @mgoin @robertgshaw2-redhat @tlrmchlsmth @yewentao256 @pavanimajety -/vllm/model_executor/layers/mamba @tdoublep +/vllm/model_executor/layers/mamba @tdoublep @tomeras91 /vllm/model_executor/layers/mamba/gdn_linear_attn.py @tdoublep @ZJY0516 /vllm/model_executor/model_loader @22quinn /vllm/model_executor/layers/batch_invariant.py @yewentao256 @@ -82,7 +82,7 @@ CMakeLists.txt @tlrmchlsmth @LucasWilkinson /tests/v1/core @WoosukKwon @robertgshaw2-redhat @njhill @ywang96 @alexm-redhat @heheda12345 @ApostaC @orozery /tests/weight_loading @mgoin @youkaichao @yewentao256 /tests/lora @jeejeelee -/tests/models/language/generation/test_hybrid.py @tdoublep +/tests/models/language/generation/test_hybrid.py @tdoublep @tomeras91 /tests/v1/kv_connector/nixl_integration @NickLucche /tests/v1/kv_connector @ApostaC @orozery /tests/v1/kv_offload @ApostaC @orozery @@ -126,6 +126,11 @@ mkdocs.yaml @hmellor /vllm/platforms/xpu.py @jikunshang /docker/Dockerfile.xpu @jikunshang +# Nemotron-specific files +/vllm/model_executor/models/*nemotron* @tomeras91 +/vllm/transformers_utils/configs/*nemotron* @tomeras91 +/tests/**/*nemotron* @tomeras91 + # Qwen-specific files /vllm/attention/backends/dual_chunk_flash_attn.py @sighingnow /vllm/model_executor/models/qwen* @sighingnow From e8b055a5ac1d16285545087d60f898085a485959 Mon Sep 17 00:00:00 2001 From: Matthias Gehre Date: Mon, 30 Mar 2026 16:30:52 +0200 Subject: [PATCH 0022/1601] [Bugfix] Handle ParallelLMHead in compressed-tensors get_quant_method (#37291) Signed-off-by: Matthias Gehre Co-authored-by: Michael Goin --- tests/quantization/test_compressed_tensors.py | 79 ++++++++++++++++++- .../compressed_tensors/compressed_tensors.py | 10 +++ 2 files changed, 88 insertions(+), 1 deletion(-) diff --git a/tests/quantization/test_compressed_tensors.py b/tests/quantization/test_compressed_tensors.py index 913b755b9074..f23506b00a7c 100644 --- a/tests/quantization/test_compressed_tensors.py +++ b/tests/quantization/test_compressed_tensors.py @@ -5,13 +5,20 @@ Run `pytest tests/quantization/test_compressed_tensors.py`. """ +from unittest.mock import Mock + import pytest import torch -from compressed_tensors.quantization import QuantizationType +from compressed_tensors.quantization import ( + QuantizationArgs, + QuantizationStrategy, + QuantizationType, +) from tests.models.utils import check_logprobs_close from vllm.model_executor.layers.fused_moe import UnquantizedFusedMoEMethod from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501 + CompressedTensorsConfig, CompressedTensorsLinearMethod, CompressedTensorsW4A4Fp4, CompressedTensorsW4A8Fp8, @@ -26,6 +33,7 @@ from vllm.model_executor.layers.quantization.utils.nvfp4_utils import ( cutlass_fp4_supported, ) +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.platforms import current_platform from vllm.v1.attention.backends.fa_utils import get_flash_attn_version @@ -558,3 +566,72 @@ def test_w4a16_moe_torch_compile(vllm_runner): ) as llm: output = llm.generate_greedy("Hi", max_tokens=1) assert output + + +def _make_ct_config(*, target: str = "Linear") -> CompressedTensorsConfig: + """Build a minimal CompressedTensorsConfig with INT8 channel quant.""" + weight_quant = QuantizationArgs( + num_bits=8, + type=QuantizationType.INT, + strategy=QuantizationStrategy.CHANNEL, + symmetric=True, + dynamic=False, + ) + return CompressedTensorsConfig( + target_scheme_map={ + target: { + "weights": weight_quant, + "input_activations": None, + "format": "pack-quantized", + } + }, + ignore=[], + quant_format="pack-quantized", + sparsity_scheme_map={}, + sparsity_ignore_list=[], + ) + + +def test_get_quant_method_returns_linear_method_for_parallel_lm_head(): + """ParallelLMHead whose name matches a target must get a quantised method.""" + config = _make_ct_config(target="re:.*lm_head") + mock_lm_head = Mock(spec=ParallelLMHead) + mock_lm_head.__class__ = ParallelLMHead + + method = config.get_quant_method(mock_lm_head, prefix="model.lm_head") + + assert isinstance(method, CompressedTensorsLinearMethod), ( + f"Expected CompressedTensorsLinearMethod, got {type(method).__name__}" + ) + + +def test_get_quant_method_returns_none_for_ignored_parallel_lm_head(): + """ParallelLMHead on the ignore list should be left unquantized (None).""" + config = _make_ct_config(target="re:.*lm_head") + config.ignore = ["re:.*lm_head"] + mock_lm_head = Mock(spec=ParallelLMHead) + mock_lm_head.__class__ = ParallelLMHead + + method = config.get_quant_method(mock_lm_head, prefix="model.lm_head") + + assert method is None, ( + f"Expected None for ignored ParallelLMHead, got {type(method).__name__}" + ) + + +def test_get_quant_method_returns_none_for_unmatched_parallel_lm_head(): + """ParallelLMHead with target='Linear' (typical real model) must not crash. + + Most compressed-tensors models only target 'Linear'. ParallelLMHead does + not match that target, so get_quant_method should return None (unquantized) + instead of raising ValueError. + """ + config = _make_ct_config(target="Linear") + mock_lm_head = Mock(spec=ParallelLMHead) + mock_lm_head.__class__ = ParallelLMHead + + method = config.get_quant_method(mock_lm_head, prefix="model.lm_head") + + assert method is None, ( + f"Expected None for unmatched ParallelLMHead, got {type(method).__name__}" + ) diff --git a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py index 4fcc468c6cfb..6ca65cdb1887 100644 --- a/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +++ b/vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py @@ -62,6 +62,7 @@ should_ignore_layer, ) from vllm.model_executor.layers.quantization.kv_cache import BaseKVCacheMethod +from vllm.model_executor.layers.vocab_parallel_embedding import ParallelLMHead from vllm.platforms import current_platform if TYPE_CHECKING: @@ -179,6 +180,15 @@ def get_quant_method( else: return quant_method + if isinstance(layer, ParallelLMHead): + try: + quant_scheme = self.get_scheme(layer=layer, layer_name=prefix) + except ValueError: + quant_scheme = None + if quant_scheme is not None: + layer.scheme = quant_scheme + return CompressedTensorsLinearMethod(self) + if isinstance(layer, Attention): return CompressedTensorsKVCacheMethod(self) if isinstance(layer, FusedMoE): From dbdd9ae06738bb8dc5b3fe6c99b3dd5bcd5d5526 Mon Sep 17 00:00:00 2001 From: Hongxia Yang <62075498+hongxiayang@users.noreply.github.com> Date: Mon, 30 Mar 2026 11:49:23 -0400 Subject: [PATCH 0023/1601] [ROCm][Bugfix] fix exception related to trust_remote_code for MiniMax-M2.1-MXFP4 (#37698) Signed-off-by: Hongxia Yang Co-authored-by: Hongxia Yang --- .../test_quark_maybe_update_config.py | 63 +++++++++++++++++++ vllm/config/vllm.py | 5 +- .../model_executor/layers/quantization/awq.py | 8 ++- .../layers/quantization/awq_marlin.py | 8 ++- .../layers/quantization/base_config.py | 16 ++++- .../layers/quantization/cpu_wna16.py | 8 ++- .../layers/quantization/gptq.py | 8 ++- .../layers/quantization/gptq_marlin.py | 8 ++- .../layers/quantization/quark/quark.py | 37 +++++++---- 9 files changed, 142 insertions(+), 19 deletions(-) create mode 100644 tests/quantization/test_quark_maybe_update_config.py diff --git a/tests/quantization/test_quark_maybe_update_config.py b/tests/quantization/test_quark_maybe_update_config.py new file mode 100644 index 000000000000..0142e869c22c --- /dev/null +++ b/tests/quantization/test_quark_maybe_update_config.py @@ -0,0 +1,63 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +"""Unit tests for QuarkConfig.maybe_update_config. + +Fetches real HF configs (metadata only, no model weights) to verify +that dynamic_mxfp4_quant is only enabled for DeepSeek-V3-family models. + +Run: pytest tests/quantization/test_quark_maybe_update_config.py -v +""" + +import pytest +from transformers import AutoConfig + +from vllm.model_executor.layers.quantization.quark.quark import QuarkConfig + + +def _make_quark_config() -> QuarkConfig: + """Create a minimal QuarkConfig for testing.""" + return QuarkConfig(quant_config={}, kv_cache_group=[], pack_method="reorder") + + +# --------------------------------------------------------------------------- +# Non-deepseek models must not flip dynamic_mxfp4_quant +# --------------------------------------------------------------------------- +@pytest.mark.parametrize( + "model_name", + ["amd/MiniMax-M2.1-MXFP4"], +) +def test_non_deepseek_model_stays_false(model_name: str): + """Non-deepseek_v3 models must not enable dynamic_mxfp4_quant.""" + hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + qcfg = _make_quark_config() + + qcfg.maybe_update_config(model_name, hf_config=hf_config) + + assert qcfg.dynamic_mxfp4_quant is False + + +# --------------------------------------------------------------------------- +# DeepSeek-V3 family + fp4 must enable dynamic_mxfp4_quant +# --------------------------------------------------------------------------- +@pytest.mark.parametrize( + "model_name", + ["amd/DeepSeek-R1-MXFP4-ASQ"], +) +def test_deepseek_family_fp4_enables_flag(model_name: str): + hf_config = AutoConfig.from_pretrained(model_name, trust_remote_code=True) + qcfg = _make_quark_config() + + qcfg.maybe_update_config(model_name, hf_config=hf_config) + + assert qcfg.dynamic_mxfp4_quant is True + + +# --------------------------------------------------------------------------- +# Missing hf_config → warn and stay False +# --------------------------------------------------------------------------- +def test_missing_hf_config_stays_false(): + qcfg = _make_quark_config() + + qcfg.maybe_update_config("some/model") + + assert qcfg.dynamic_mxfp4_quant is False diff --git a/vllm/config/vllm.py b/vllm/config/vllm.py index b6be7f10bdb0..55f35c13553f 100644 --- a/vllm/config/vllm.py +++ b/vllm/config/vllm.py @@ -526,7 +526,10 @@ def _get_quantization_config( f"method {model_config.quantization}. Supported dtypes: " f"{supported_dtypes}" ) - quant_config.maybe_update_config(model_config.model) + quant_config.maybe_update_config( + model_config.model, + hf_config=model_config.hf_config, + ) return quant_config return None diff --git a/vllm/model_executor/layers/quantization/awq.py b/vllm/model_executor/layers/quantization/awq.py index 3cf3116f0670..58bb75d0a9ed 100644 --- a/vllm/model_executor/layers/quantization/awq.py +++ b/vllm/model_executor/layers/quantization/awq.py @@ -5,6 +5,7 @@ import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE +from transformers import PretrainedConfig from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -146,7 +147,12 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): self.modules_to_not_convert ) - def maybe_update_config(self, model_name: str, revision: str | None = None): + def maybe_update_config( + self, + model_name: str, + hf_config: PretrainedConfig | None = None, + revision: str | None = None, + ): if self.modules_to_not_convert: return diff --git a/vllm/model_executor/layers/quantization/awq_marlin.py b/vllm/model_executor/layers/quantization/awq_marlin.py index 426b9aa71562..03dfaa7949c0 100644 --- a/vllm/model_executor/layers/quantization/awq_marlin.py +++ b/vllm/model_executor/layers/quantization/awq_marlin.py @@ -6,6 +6,7 @@ import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from torch.nn import Parameter +from transformers import PretrainedConfig import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops @@ -332,7 +333,12 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): self.modules_to_not_convert ) - def maybe_update_config(self, model_name: str, revision: str | None = None): + def maybe_update_config( + self, + model_name: str, + hf_config: PretrainedConfig | None = None, + revision: str | None = None, + ): if self.modules_to_not_convert: return diff --git a/vllm/model_executor/layers/quantization/base_config.py b/vllm/model_executor/layers/quantization/base_config.py index 06fe4270c713..eedc62f7d4d5 100644 --- a/vllm/model_executor/layers/quantization/base_config.py +++ b/vllm/model_executor/layers/quantization/base_config.py @@ -7,6 +7,7 @@ import torch from torch import nn +from transformers import PretrainedConfig if TYPE_CHECKING: from vllm.model_executor.layers.quantization import QuantizationMethods @@ -168,10 +169,23 @@ def apply_vllm_mapper( # noqa: B027 # TODO (@kylesayrs): add implementations for all subclasses pass - def maybe_update_config(self, model_name: str): # noqa: B027 + def maybe_update_config( # noqa: B027 + self, + model_name: str, + hf_config: PretrainedConfig | None = None, + revision: str | None = None, + ): """ Interface to update values after config initialization. + + Args: + model_name: The name of the model + hf_config: The Hugging Face config of the model + revision: The revision of the model + Returns: """ + # TODO: revision is never passed currently in vllm.py, + # but is used in subclasses, should we remove this parameter? pass def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool: diff --git a/vllm/model_executor/layers/quantization/cpu_wna16.py b/vllm/model_executor/layers/quantization/cpu_wna16.py index ea7afef27ebd..3dba317438ec 100644 --- a/vllm/model_executor/layers/quantization/cpu_wna16.py +++ b/vllm/model_executor/layers/quantization/cpu_wna16.py @@ -5,6 +5,7 @@ import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE +from transformers import PretrainedConfig from vllm._custom_ops import ( cpu_gemm_wna16, @@ -133,7 +134,12 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): self.modules_to_not_convert ) - def maybe_update_config(self, model_name: str, revision: str | None = None): + def maybe_update_config( + self, + model_name: str, + hf_config: PretrainedConfig | None = None, + revision: str | None = None, + ): if self.modules_to_not_convert: return diff --git a/vllm/model_executor/layers/quantization/gptq.py b/vllm/model_executor/layers/quantization/gptq.py index 154347a930a9..458741478538 100644 --- a/vllm/model_executor/layers/quantization/gptq.py +++ b/vllm/model_executor/layers/quantization/gptq.py @@ -9,6 +9,7 @@ import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE from torch.nn.parameter import Parameter +from transformers import PretrainedConfig from vllm import _custom_ops as ops from vllm.logger import init_logger @@ -193,7 +194,12 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper: "WeightsMapper"): self.modules_in_block_to_quantize ) - def maybe_update_config(self, model_name: str, revision: str | None = None): + def maybe_update_config( + self, + model_name: str, + hf_config: PretrainedConfig | None = None, + revision: str | None = None, + ): if self.modules_in_block_to_quantize: if is_list_of(self.modules_in_block_to_quantize, list): # original modules_in_block_to_quantize: list[list[str]] diff --git a/vllm/model_executor/layers/quantization/gptq_marlin.py b/vllm/model_executor/layers/quantization/gptq_marlin.py index d7b2a366e1f0..8e367c88346f 100644 --- a/vllm/model_executor/layers/quantization/gptq_marlin.py +++ b/vllm/model_executor/layers/quantization/gptq_marlin.py @@ -6,6 +6,7 @@ import torch from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE +from transformers import PretrainedConfig import vllm.model_executor.layers.fused_moe # noqa from vllm import _custom_ops as ops @@ -299,7 +300,12 @@ def apply_vllm_mapper(self, hf_to_vllm_mapper): self.modules_in_block_to_quantize ) - def maybe_update_config(self, model_name: str, revision: str | None = None): + def maybe_update_config( + self, + model_name: str, + hf_config: PretrainedConfig | None = None, + revision: str | None = None, + ): if self.modules_in_block_to_quantize: if is_list_of(self.modules_in_block_to_quantize, list): # original modules_in_block_to_quantize: list[list[str]] diff --git a/vllm/model_executor/layers/quantization/quark/quark.py b/vllm/model_executor/layers/quantization/quark/quark.py index 78c64bac6187..d0362cedcf2b 100644 --- a/vllm/model_executor/layers/quantization/quark/quark.py +++ b/vllm/model_executor/layers/quantization/quark/quark.py @@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, Any, cast import torch +from transformers import PretrainedConfig from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention @@ -36,7 +37,6 @@ ) from vllm.model_executor.models.utils import WeightsMapper from vllm.platforms import current_platform -from vllm.transformers_utils.config import get_config if TYPE_CHECKING: from vllm.model_executor.models.utils import WeightsMapper @@ -45,6 +45,10 @@ logger = init_logger(__name__) +# model_type values that use dynamic MXFP4 re-quantization for +# OCP MX fp4 Quark checkpoints +_DEEPSEEK_V3_FAMILY_MODEL_TYPES = frozenset({"deepseek_v3"}) + class QuarkConfig(QuantizationConfig): def __init__( @@ -63,19 +67,28 @@ def __init__( self.pack_method = pack_method self.dynamic_mxfp4_quant = False - def maybe_update_config(self, model_name: str, revision: str | None = None): - self.hf_config = get_config( - model=model_name, - trust_remote_code=False, # or get from model_config if available - revision=revision, - config_format="auto", - ) + def maybe_update_config( + self, + model_name: str, + hf_config: PretrainedConfig | None = None, + revision: str | None = None, + ): + """Enable dynamic MXFP4 only for DeepSeek-V3-family + fp4 Quark checkpoints.""" - quant_config = getattr(self.hf_config, "quantization_config", None) + if ( + getattr(hf_config, "model_type", None) + not in _DEEPSEEK_V3_FAMILY_MODEL_TYPES + ): + return + + quant_config = getattr(hf_config, "quantization_config", None) if quant_config is not None: - quant_dtype = quant_config["global_quant_config"]["weight"]["dtype"] - model_type = self.hf_config.model_type - if quant_dtype == "fp4" and model_type == "deepseek_v3": + quant_dtype = ( + quant_config.get("global_quant_config", {}) + .get("weight", {}) + .get("dtype") + ) + if quant_dtype == "fp4": self.dynamic_mxfp4_quant = True def get_linear_method(self) -> "QuarkLinearMethod": From 8e6293e838f94430d7aef14cbb7308c10b99ea11 Mon Sep 17 00:00:00 2001 From: roikoren755 <26850796+roikoren755@users.noreply.github.com> Date: Mon, 30 Mar 2026 19:33:49 +0300 Subject: [PATCH 0024/1601] [Mamba] Add stochastic rounding support (#35753) Signed-off-by: Roi Koren --- tests/kernels/mamba/test_mamba_ssm.py | 54 +++++++++++++++++ vllm/config/cache.py | 36 ++++++++++- vllm/engine/arg_utils.py | 13 ++++ .../layers/mamba/mamba_mixer.py | 2 + .../layers/mamba/mamba_mixer2.py | 2 + .../layers/mamba/ops/mamba_ssm.py | 59 ++++++++++++++++++- vllm/model_executor/models/plamo2.py | 2 + 7 files changed, 166 insertions(+), 2 deletions(-) diff --git a/tests/kernels/mamba/test_mamba_ssm.py b/tests/kernels/mamba/test_mamba_ssm.py index e8cbba29f363..065739cf9241 100644 --- a/tests/kernels/mamba/test_mamba_ssm.py +++ b/tests/kernels/mamba/test_mamba_ssm.py @@ -12,6 +12,7 @@ selective_scan_fn, selective_state_update, ) +from vllm.platforms import current_platform from vllm.utils.torch_utils import set_random_seed from vllm.v1.attention.backends.utils import PAD_SLOT_ID @@ -429,6 +430,59 @@ def test_selective_state_update(dim, dstate, has_z, itype): assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) +@pytest.mark.parametrize("philox_rounds", [0, 4]) +@pytest.mark.parametrize("has_z", [False, True]) +@pytest.mark.parametrize("dstate", [16, 64]) +@pytest.mark.parametrize("dim", [2048, 4096]) +@pytest.mark.skipif( + not ( + current_platform.is_cuda() and current_platform.is_device_capability_family(100) + ), + reason="Stochastic rounding in triton is only supported" + " on compute capability 10.0 CUDA devices.", +) +def test_selective_state_update_stochastic_rounding(dim, dstate, has_z, philox_rounds): + device = "cuda" + rtol, atol = 5e-3, 1e-1 + # set seed + set_random_seed(0) + batch_size = 1 + state = torch.randn(batch_size, dim, dstate, dtype=torch.float16, device=device) + x = torch.randn(batch_size, dim, device=device, dtype=torch.bfloat16) + out = torch.empty_like(x) + dt = torch.randn(batch_size, dim, device=device, dtype=torch.bfloat16) + dt_bias = torch.rand(dim, device=device) - 4.0 + A = -torch.rand(dim, dstate, device=device) - 1.0 + B = torch.randn(batch_size, dstate, device=device) + C = torch.randn(batch_size, dstate, device=device) + D = torch.randn(dim, device=device) + z = torch.randn_like(x) if has_z else None + # Reference uses fp32 state to get ground truth + state_ref = state.float() + selective_state_update( + state, + x, + dt, + A, + B, + C, + D=D, + z=z, + dt_bias=dt_bias, + dt_softplus=True, + out=out, + enable_stochastic_rounding=True, + cache_philox_rounds=philox_rounds, + ) + out_ref = selective_state_update_ref( + state_ref, x, dt, A, B, C, D=D, z=z, dt_bias=dt_bias, dt_softplus=True + ) + + assert state.dtype == torch.float16 + assert torch.allclose(state, state_ref.to(torch.float16), rtol=rtol, atol=atol) + assert torch.allclose(out, out_ref, rtol=rtol, atol=atol) + + @pytest.mark.parametrize("itype", [torch.float32, torch.bfloat16]) @pytest.mark.parametrize("has_z", [False, True]) @pytest.mark.parametrize("dstate", [16, 64]) diff --git a/vllm/config/cache.py b/vllm/config/cache.py index dcc93d987eda..1cadb43183d2 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -109,12 +109,20 @@ class CacheConfig: mamba_cache_mode: MambaCacheMode = "none" """The cache strategy for Mamba layers. - "none": set when prefix caching is disabled. - - "all": cache the mamba state of all tokens at position i * block_size. This is + - "all": cache the mamba state of all tokens at position i * block_size. This is the default behavior (for models that support it) when prefix caching is enabled. - "align": only cache the mamba state of the last token of each scheduler step and when the token is at position i * block_size. """ + enable_mamba_cache_stochastic_rounding: bool = False + """Enable stochastic rounding when writing SSM state to fp16 cache. + Uses random bits to unbias the rounding error, which can improve + numerical stability for long sequences.""" + mamba_cache_philox_rounds: int = 0 + """Number of Philox PRNG rounds for stochastic rounding random number + generation. 0 uses the Triton default. Higher values improve randomness + quality at the cost of compute.""" # Will be set after profiling. num_gpu_blocks: int | None = field(default=None, init=False) @@ -231,3 +239,29 @@ def _validate_cache_dtype(cls, cache_dtype: CacheDType) -> CacheDType: "scaling factor." ) return cache_dtype + + def __post_init__(self): + if self.enable_mamba_cache_stochastic_rounding: + from vllm.platforms import current_platform + + if not current_platform.is_cuda(): + raise ValueError( + "Stochastic rounding for Mamba cache is only supported " + "on NVIDIA CUDA platforms. Please do not specify " + "`--enable-mamba-cache-stochastic-rounding`." + ) + if not current_platform.is_device_capability_family(100): + raise ValueError( + "Stochastic rounding for Mamba cache requires compute " + "capability 10.0 (data center Blackwell). The `cvt.rs` PTX " + "instruction is not supported on your GPU. Please do not specify " + "`--enable-mamba-cache-stochastic-rounding`." + ) + if self.mamba_ssm_cache_dtype != "float16": + raise ValueError( + "Stochastic rounding for Mamba cache requires " + "the SSM cache to be float16. Please set it explicitly, " + "by specifying `--mamba-ssm-cache-dtype float16`, or disable " + "stochastic rounding by not specifying " + "`--enable-mamba-cache-stochastic-rounding`." + ) diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index e1772ab1d427..0c9cf2ae9b14 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -604,6 +604,10 @@ class EngineArgs: mamba_ssm_cache_dtype: MambaDType = CacheConfig.mamba_ssm_cache_dtype mamba_block_size: int | None = get_field(CacheConfig, "mamba_block_size") mamba_cache_mode: MambaCacheMode = CacheConfig.mamba_cache_mode + enable_mamba_cache_stochastic_rounding: bool = ( + CacheConfig.enable_mamba_cache_stochastic_rounding + ) + mamba_cache_philox_rounds: int = CacheConfig.mamba_cache_philox_rounds additional_config: dict[str, Any] = get_field(VllmConfig, "additional_config") @@ -1024,6 +1028,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: cache_group.add_argument( "--mamba-cache-mode", **cache_kwargs["mamba_cache_mode"] ) + cache_group.add_argument( + "--enable-mamba-cache-stochastic-rounding", + **cache_kwargs["enable_mamba_cache_stochastic_rounding"], + ) + cache_group.add_argument( + "--mamba-cache-philox-rounds", **cache_kwargs["mamba_cache_philox_rounds"] + ) cache_group.add_argument( "--kv-offloading-size", **cache_kwargs["kv_offloading_size"] ) @@ -1590,6 +1601,8 @@ def create_engine_config( mamba_ssm_cache_dtype=self.mamba_ssm_cache_dtype, mamba_block_size=self.mamba_block_size, mamba_cache_mode=self.mamba_cache_mode, + enable_mamba_cache_stochastic_rounding=self.enable_mamba_cache_stochastic_rounding, + mamba_cache_philox_rounds=self.mamba_cache_philox_rounds, kv_offloading_size=self.kv_offloading_size, kv_offloading_backend=self.kv_offloading_backend, ) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer.py b/vllm/model_executor/layers/mamba/mamba_mixer.py index 82ca367fb26c..d79af2e27876 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer.py @@ -428,6 +428,8 @@ def forward_impl(self, hidden_states: torch.Tensor, output: torch.Tensor): state_batch_indices=state_indices_tensor_d_input, dst_state_batch_indices=state_indices_tensor_d_output, out=scan_outputs_d, + enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding, + cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds, ) scan_outputs_d = scan_outputs_d.transpose(0, 1) diff --git a/vllm/model_executor/layers/mamba/mamba_mixer2.py b/vllm/model_executor/layers/mamba/mamba_mixer2.py index 9486e182ec46..041405b05a08 100644 --- a/vllm/model_executor/layers/mamba/mamba_mixer2.py +++ b/vllm/model_executor/layers/mamba/mamba_mixer2.py @@ -888,6 +888,8 @@ def conv_ssm_forward( num_accepted_tokens=num_accepted_tokens, cu_seqlens=query_start_loc_d, is_blackwell=self.is_blackwell, + enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding, + cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds, ) def get_state_dtype(self) -> tuple[torch.dtype, torch.dtype]: diff --git a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py index 1cd077758326..793471fdaab5 100644 --- a/vllm/model_executor/layers/mamba/ops/mamba_ssm.py +++ b/vllm/model_executor/layers/mamba/ops/mamba_ssm.py @@ -28,6 +28,21 @@ def softplus(dt): return dt +@triton.jit +def convert_rs_fp16x2(x: tl.tensor, rand: tl.tensor) -> tl.tensor: + y = tl.inline_asm_elementwise( + asm="""{ +cvt.rs.f16x2.f32 $0, $2, $1, $3; +}""", + constraints="=r,r,r,r,r", + args=(x, rand), + dtype=tl.float16, + is_pure=True, + pack=2, + ) + return y + + @triton.heuristics({"HAS_DT_BIAS": lambda args: args["dt_bias_ptr"] is not None}) @triton.heuristics({"HAS_D": lambda args: args["D_ptr"] is not None}) @triton.heuristics({"HAS_Z": lambda args: args["z_ptr"] is not None}) @@ -48,6 +63,7 @@ def softplus(dt): def _selective_scan_update_kernel( # Pointers to matrices state_ptr, + rand_seed_ptr, x_ptr, dt_ptr, dt_bias_ptr, @@ -113,6 +129,8 @@ def _selective_scan_update_kernel( IS_SPEC_DECODING: tl.constexpr, IS_VARLEN: tl.constexpr, BLOCK_SIZE_DSTATE: tl.constexpr, + USE_RS_ROUNDING: tl.constexpr, + PHILOX_ROUNDS: tl.constexpr, ): pid_m = tl.program_id(axis=0) pid_b = tl.program_id(axis=1) @@ -267,7 +285,35 @@ def _selective_scan_update_kernel( z_ptr += stride_z_batch if not IS_SPEC_DECODING: - tl.store(dst_state_ptrs, state.to(dst_state_ptrs.dtype.element_ty), mask=mask) + if USE_RS_ROUNDING: + # Load random seed + rand_seed = tl.load(rand_seed_ptr) + # Generate random offsets for each element in state + if HAS_STATE_BATCH_INDICES: + rand_offsets = ( + state_batch_idx * stride_state_batch + pid_h * stride_state_head + ) + else: + rand_offsets = pid_b * stride_state_batch + pid_h * stride_state_head + rand_offsets += ( + offs_m[:, None] * stride_state_dim + + offs_n[None, :] * stride_state_dstate + ) + # Generate random 32-bits for each element in state + if PHILOX_ROUNDS > 0: + rand = tl.randint(rand_seed, rand_offsets, PHILOX_ROUNDS) + else: + rand = tl.randint(rand_seed, rand_offsets) + # Convert state to fp16 with RS rounding + state = convert_rs_fp16x2(state, rand) + tl.static_assert(state.dtype == tl.float16, "state must be fp16") + tl.static_assert( + dst_state_ptrs.dtype.element_ty == tl.float16, + "dst_state_ptrs must be fp16", + ) + else: + state = state.to(dst_state_ptrs.dtype.element_ty) + tl.store(dst_state_ptrs, state, mask=mask) def selective_state_update( @@ -288,6 +334,8 @@ def selective_state_update( num_accepted_tokens=None, cu_seqlens=None, is_blackwell=False, + enable_stochastic_rounding=False, + cache_philox_rounds=0, ): """ Argument: @@ -419,9 +467,16 @@ def selective_state_update( and dt.stride(-1) == 0 and dt_bias.stride(-1) == 0 ) + rand_seed = ( + torch.randint(0, 2**32, (1,), device=state.device) + if enable_stochastic_rounding + else None + ) + with torch.accelerator.device_index(x.device.index): _selective_scan_update_kernel[grid]( state, + rand_seed, x, dt, dt_bias, @@ -476,6 +531,8 @@ def selective_state_update( tie_hdim, BLOCK_SIZE_M, num_warps=num_warps, + USE_RS_ROUNDING=enable_stochastic_rounding, + PHILOX_ROUNDS=cache_philox_rounds, ) diff --git a/vllm/model_executor/models/plamo2.py b/vllm/model_executor/models/plamo2.py index e38b7b166ba9..44b1207745ec 100644 --- a/vllm/model_executor/models/plamo2.py +++ b/vllm/model_executor/models/plamo2.py @@ -445,6 +445,8 @@ def forward_impl( dt_softplus=True, state_batch_indices=state_indices_tensor_d, out=preallocated_ssm_out_d.view(num_decodes, -1, self.head_dim), + enable_stochastic_rounding=self.cache_config.enable_mamba_cache_stochastic_rounding, + cache_philox_rounds=self.cache_config.mamba_cache_philox_rounds, ) # 4. Final linear projection From b4a2f3ac369043b4a734160215575f2bc8037678 Mon Sep 17 00:00:00 2001 From: Johnny Date: Mon, 30 Mar 2026 18:36:18 +0200 Subject: [PATCH 0025/1601] [NVIDIA] Bugfix NVFP4 DGX Spark and RTX50 (#38423) Signed-off-by: johnnynunez Signed-off-by: Johnny --- CMakeLists.txt | 2 +- csrc/quantization/fp4/nvfp4_quant_entry.cu | 29 +++++++++++++++++++ .../quantization/fp4/nvfp4_scaled_mm_entry.cu | 14 ++++++++- .../quantization/machete/machete_mainloop.cuh | 1 + docker/Dockerfile | 5 +++- docker/Dockerfile.nightly_torch | 7 +++-- docker/versions.json | 2 +- requirements/cuda.txt | 4 +-- .../moe/test_unquantized_backend_selection.py | 4 --- .../fused_moe/experts/trtllm_fp8_moe.py | 5 ++++ .../fused_moe/experts/trtllm_nvfp4_moe.py | 8 +++++ .../fused_moe/flashinfer_cutlass_moe.py | 2 +- .../layers/fused_moe/oracle/unquantized.py | 6 ++-- .../fused_moe/unquantized_fused_moe_method.py | 1 - .../layers/quantization/utils/nvfp4_utils.py | 16 ++++++++-- 15 files changed, 86 insertions(+), 20 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index cf59f18eb7e7..ff57473cf163 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -309,7 +309,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SET(CUTLASS_ENABLE_HEADERS_ONLY ON CACHE BOOL "Enable only the header library") # Set CUTLASS_REVISION. Used for FetchContent. Also fixes some bogus messages when building. - set(CUTLASS_REVISION "v4.2.1") + set(CUTLASS_REVISION "v4.4.2") # Use the specified CUTLASS source directory for compilation if VLLM_CUTLASS_SRC_DIR is provided if (DEFINED ENV{VLLM_CUTLASS_SRC_DIR}) diff --git a/csrc/quantization/fp4/nvfp4_quant_entry.cu b/csrc/quantization/fp4/nvfp4_quant_entry.cu index 8b5a1fd22cb7..8729b01e4442 100644 --- a/csrc/quantization/fp4/nvfp4_quant_entry.cu +++ b/csrc/quantization/fp4/nvfp4_quant_entry.cu @@ -16,6 +16,7 @@ #include +#include "cutlass_extensions/common.hpp" #include "nvfp4_utils.cuh" #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ @@ -53,12 +54,27 @@ void silu_and_mul_scaled_fp4_experts_quant_sm1xxa( torch::Tensor const& output_scale_offset_by_experts); #endif +static bool nvfp4_quant_sm_supported() { + const int32_t sm = get_sm_version_num(); +#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100 + if (sm >= 100 && sm < 120) return true; +#endif +#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120 + if (sm >= 120 && sm < 130) return true; +#endif + return false; +} + void scaled_fp4_quant_out(torch::Tensor const& input, torch::Tensor const& input_sf, bool is_sf_swizzled_layout, torch::Tensor& output, torch::Tensor& output_sf) { #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + TORCH_CHECK(nvfp4_quant_sm_supported(), + "No compiled nvfp4 quantization kernel for SM ", + get_sm_version_num(), + ". Recompile with the appropriate CUDA arch."); return scaled_fp4_quant_sm1xxa(output, input, output_sf, input_sf, is_sf_swizzled_layout); #endif @@ -100,6 +116,10 @@ void scaled_fp4_experts_quant( torch::Tensor const& output_scale_offset_by_experts) { #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + TORCH_CHECK(nvfp4_quant_sm_supported(), + "No compiled nvfp4 experts quantization kernel for SM ", + get_sm_version_num(), + ". Recompile with the appropriate CUDA arch."); return scaled_fp4_experts_quant_sm1xxa( output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts); @@ -112,6 +132,10 @@ void silu_and_mul_nvfp4_quant(torch::Tensor& output, torch::Tensor& output_sf, torch::Tensor& input, torch::Tensor& input_sf) { #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + TORCH_CHECK(nvfp4_quant_sm_supported(), + "No compiled silu_and_mul nvfp4 quantization kernel for SM ", + get_sm_version_num(), + ". Recompile with the appropriate CUDA arch."); return silu_and_mul_nvfp4_quant_sm1xxa(output, output_sf, input, input_sf); #endif TORCH_CHECK_NOT_IMPLEMENTED( @@ -125,6 +149,11 @@ void silu_and_mul_scaled_fp4_experts_quant( torch::Tensor const& output_scale_offset_by_experts) { #if (defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100) || \ (defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120) + TORCH_CHECK(nvfp4_quant_sm_supported(), + "No compiled silu_and_mul nvfp4 experts quantization kernel " + "for SM ", + get_sm_version_num(), + ". Recompile with the appropriate CUDA arch."); return silu_and_mul_scaled_fp4_experts_quant_sm1xxa( output, output_scale, input, input_global_scale, input_offset_by_experts, output_scale_offset_by_experts); diff --git a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu index d9c4d24d8e1f..4985bf1eb156 100644 --- a/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu +++ b/csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu @@ -63,5 +63,17 @@ void cutlass_scaled_fp4_mm(torch::Tensor& D, const torch::Tensor& A, bool cutlass_scaled_mm_supports_fp4(int64_t cuda_device_capability) { int runtimeVersion; cudaRuntimeGetVersion(&runtimeVersion); - return cuda_device_capability >= 100 && runtimeVersion >= 12080; + if (runtimeVersion < 12080) return false; + // Only report support when the SM-specific kernel was actually compiled in, + // so the Python-side backend selector does not choose CUTLASS and then hit + // TORCH_CHECK_NOT_IMPLEMENTED (or worse, fall through to Marlin). +#if defined(ENABLE_NVFP4_SM100) && ENABLE_NVFP4_SM100 + if (cuda_device_capability >= 100 && cuda_device_capability < 120) + return true; +#endif +#if defined(ENABLE_NVFP4_SM120) && ENABLE_NVFP4_SM120 + if (cuda_device_capability >= 120 && cuda_device_capability < 130) + return true; +#endif + return false; } diff --git a/csrc/quantization/machete/machete_mainloop.cuh b/csrc/quantization/machete/machete_mainloop.cuh index 9f02f4f17974..a0d1ded8876a 100644 --- a/csrc/quantization/machete/machete_mainloop.cuh +++ b/csrc/quantization/machete/machete_mainloop.cuh @@ -154,6 +154,7 @@ struct MacheteCollectiveMma { struct DispatchPolicy { constexpr static int Stages = PipelineStages; using ClusterShape = ClusterShape_MNK; + using ArchTag = arch::Sm90; using Schedule = KernelScheduleType; }; diff --git a/docker/Dockerfile b/docker/Dockerfile index b26b82eb598e..2d63b39dc718 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -590,7 +590,10 @@ RUN --mount=type=cache,target=/root/.cache/uv \ # Install FlashInfer JIT cache (requires CUDA-version-specific index URL) # https://docs.flashinfer.ai/installation.html # From versions.json: .flashinfer.version -ARG FLASHINFER_VERSION=0.6.6 +# 0.6.7: CUTLASS 4.4.2 bump, fixes TMA grouped GEMM on SM12x (flashinfer#2798) +# TODO: bump to 0.6.8 when released for NVFP4/MXFP4 group GEMMs on +# SM120/SM121 (RTX 50 / DGX Spark) via flashinfer#2738 +ARG FLASHINFER_VERSION=0.6.7 RUN --mount=type=cache,target=/root/.cache/uv \ uv pip install --system flashinfer-jit-cache==${FLASHINFER_VERSION} \ --extra-index-url https://flashinfer.ai/whl/cu$(echo $CUDA_VERSION | cut -d. -f1,2 | tr -d '.') \ diff --git a/docker/Dockerfile.nightly_torch b/docker/Dockerfile.nightly_torch index 5c424980ee2d..045a09a42dc4 100644 --- a/docker/Dockerfile.nightly_torch +++ b/docker/Dockerfile.nightly_torch @@ -217,13 +217,16 @@ RUN pip install setuptools==75.6.0 packaging==23.2 ninja==1.11.1.3 build==1.2.2. # build flashinfer for torch nightly from source around 10 mins -# release version: v0.6.6 +# release version: v0.6.7 +# 0.6.7: CUTLASS 4.4.2 bump, fixes TMA grouped GEMM on SM12x (flashinfer#2798) +# TODO: bump to 0.6.8 when released for NVFP4/MXFP4 group GEMMs on +# SM120/SM121 (RTX 50 / DGX Spark) via flashinfer#2738 # todo(elainewy): cache flashinfer build result for faster build ENV CCACHE_DIR=/root/.cache/ccache RUN --mount=type=cache,target=/root/.cache/ccache \ --mount=type=cache,target=/root/.cache/uv \ echo "git clone flashinfer..." \ - && git clone --depth 1 --branch v0.6.6 --recursive https://github.com/flashinfer-ai/flashinfer.git \ + && git clone --depth 1 --branch v0.6.7 --recursive https://github.com/flashinfer-ai/flashinfer.git \ && cd flashinfer \ && git submodule update --init --recursive \ && echo "finish git clone flashinfer..." \ diff --git a/docker/versions.json b/docker/versions.json index 582d1bd54279..51be339546ea 100644 --- a/docker/versions.json +++ b/docker/versions.json @@ -68,7 +68,7 @@ "default": "true" }, "FLASHINFER_VERSION": { - "default": "0.6.6" + "default": "0.6.7" }, "GDRCOPY_CUDA_VERSION": { "default": "12.8" diff --git a/requirements/cuda.txt b/requirements/cuda.txt index fe566db357dc..6d7f9693f75c 100644 --- a/requirements/cuda.txt +++ b/requirements/cuda.txt @@ -9,8 +9,8 @@ torchaudio==2.10.0 # These must be updated alongside torch torchvision==0.25.0 # Required for phi3v processor. See https://github.com/pytorch/vision?tab=readme-ov-file#installation for corresponding version # FlashInfer should be updated together with the Dockerfile -flashinfer-python==0.6.6 -flashinfer-cubin==0.6.6 +flashinfer-python==0.6.7 +flashinfer-cubin==0.6.7 # Cap nvidia-cudnn-frontend (transitive dep of flashinfer) due to # breaking changes in 1.19.0 nvidia-cudnn-frontend>=1.13.0,<1.19.0 diff --git a/tests/kernels/moe/test_unquantized_backend_selection.py b/tests/kernels/moe/test_unquantized_backend_selection.py index 1d9e1d685756..48f1a4b94abb 100644 --- a/tests/kernels/moe/test_unquantized_backend_selection.py +++ b/tests/kernels/moe/test_unquantized_backend_selection.py @@ -57,7 +57,6 @@ def test_select_default_backend_by_platform( moe_config = make_dummy_moe_config() selected_backend = select_unquantized_moe_backend( moe_config=moe_config, - use_ep=False, use_dp=False, ) @@ -90,7 +89,6 @@ def test_select_rocm_aiter_backend(mock_aiter_enabled, mock_has_flashinfer): moe_config = make_dummy_moe_config() selected_backend = select_unquantized_moe_backend( moe_config=moe_config, - use_ep=False, use_dp=False, ) @@ -129,7 +127,6 @@ def test_select_cuda_flashinfer_trtllm_backend( selected_backend = select_unquantized_moe_backend( moe_config=moe_config, - use_ep=True, use_dp=False, ) @@ -171,7 +168,6 @@ def test_select_cuda_flashinfer_cutlass_backend( selected_backend = select_unquantized_moe_backend( moe_config=moe_config, - use_ep=True, # CUTLASS requires EP use_dp=False, # CUTLASS doesn't support DP ) diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py index c0a7dfc4978a..4cb12a8c17f3 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_fp8_moe.py @@ -406,6 +406,11 @@ def _apply_per_tensor( if self.routing_method_type == RoutingMethodType.DeepSeekV3: router_logits = router_logits.to(torch.float32) + # Currently FI requires bfloat16 routing bias. + # https://github.com/flashinfer-ai/flashinfer/issues/2909 + if e_score_correction_bias is not None: + e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16) + out = flashinfer.fused_moe.trtllm_fp8_per_tensor_scale_moe( routing_logits=router_logits, routing_bias=e_score_correction_bias, diff --git a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py index b47391c41ea0..81b778c8f4a7 100644 --- a/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py +++ b/vllm/model_executor/layers/fused_moe/experts/trtllm_nvfp4_moe.py @@ -5,6 +5,7 @@ import torch import vllm.model_executor.layers.fused_moe.modular_kernel as mk +from vllm.logger import init_logger from vllm.model_executor.layers.fused_moe.activation import MoEActivation from vllm.model_executor.layers.fused_moe.config import ( FusedMoEConfig, @@ -27,6 +28,8 @@ from vllm.platforms import current_platform from vllm.utils.flashinfer import has_flashinfer_trtllm_fused_moe +logger = init_logger(__name__) + class TrtLlmNvFp4ExpertsBase: """ @@ -315,6 +318,11 @@ def apply( else router_logits ) + # Currently FI requires bfloat16 routing bias. + # https://github.com/flashinfer-ai/flashinfer/issues/2909 + if e_score_correction_bias is not None: + e_score_correction_bias = e_score_correction_bias.to(torch.bfloat16) + # Invoke kernel. return flashinfer.fused_moe.trtllm_fp4_block_scale_moe( routing_logits=router_logits, diff --git a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py index 91f7a83f6fce..26409804c48d 100644 --- a/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +++ b/vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py @@ -361,7 +361,7 @@ def apply( fc1_expert_weights = w1 fc2_expert_weights = w2 else: - quant_scales = None + quant_scales = [] a1q_scale = None fc1_expert_weights = w1 fc2_expert_weights = w2 diff --git a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py index 9c31da10dd94..84f4df010673 100644 --- a/vllm/model_executor/layers/fused_moe/oracle/unquantized.py +++ b/vllm/model_executor/layers/fused_moe/oracle/unquantized.py @@ -70,7 +70,6 @@ def map_unquantized_backend(runner_backend: MoEBackend) -> UnquantizedMoeBackend def select_unquantized_moe_backend( moe_config: FusedMoEConfig, - use_ep: bool, use_dp: bool, ) -> UnquantizedMoeBackend: """ @@ -96,7 +95,6 @@ def _make_log_backend(backend: UnquantizedMoeBackend): # FlashInfer CUTLASS MoE is only supported on Hopper and later GPUS flashinfer_cutlass_available = ( has_flashinfer_cutlass_fused_moe() - and use_ep and (not use_dp) and current_platform.has_device_capability(90) ) @@ -161,9 +159,9 @@ def _make_log_backend(backend: UnquantizedMoeBackend): "to enable it for better performance.", scope="local", ) - elif use_ep and (not use_dp): + elif not use_dp and flashinfer_cutlass_available: logger.info_once( - "FlashInfer MoE is available for EP" + "FlashInfer CUTLASS MoE is available" " but not enabled, consider setting" " VLLM_USE_FLASHINFER_MOE_FP16=1 to enable it.", scope="local", diff --git a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py index 38b552b021ba..cf1afcb29b21 100644 --- a/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +++ b/vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py @@ -61,7 +61,6 @@ def __init__(self, moe: FusedMoEConfig): super().__init__(moe) self.unquantized_backend = select_unquantized_moe_backend( moe_config=self.moe, - use_ep=self.moe.moe_parallel_config.use_ep, use_dp=self.moe.moe_parallel_config.dp_size > 1, ) diff --git a/vllm/model_executor/layers/quantization/utils/nvfp4_utils.py b/vllm/model_executor/layers/quantization/utils/nvfp4_utils.py index bcb4769e4c9b..f21f2ef23f4d 100644 --- a/vllm/model_executor/layers/quantization/utils/nvfp4_utils.py +++ b/vllm/model_executor/layers/quantization/utils/nvfp4_utils.py @@ -55,8 +55,16 @@ def select_nvfp4_linear_backend() -> NvFp4LinearBackend: elif envs.VLLM_USE_NVFP4_CT_EMULATIONS: backend = NvFp4LinearBackend.EMULATION elif envs.VLLM_NVFP4_GEMM_BACKEND is None: - # Auto-select best available backend - if current_platform.has_device_capability(100) and has_flashinfer(): + # Auto-select best available backend. + # cutlass_fp4_supported() checks that the vLLM NVFP4 kernels (both + # quantization and GEMM) were compiled for the current SM version. + # FlashInfer backends still rely on the vLLM quantization kernels, + # so we gate them on the same check. + if ( + cutlass_fp4_supported() + and current_platform.has_device_capability(100) + and has_flashinfer() + ): backend = NvFp4LinearBackend.FLASHINFER_CUTLASS elif cutlass_fp4_supported(): backend = NvFp4LinearBackend.VLLM_CUTLASS @@ -72,6 +80,10 @@ def select_nvfp4_linear_backend() -> NvFp4LinearBackend: NvFp4LinearBackend.FLASHINFER_CUDNN, ): assert has_flashinfer(), f"FlashInfer is required for {backend}" + assert cutlass_fp4_supported(), ( + f"{backend} requires vLLM NVFP4 quantization kernels compiled " + f"for the current GPU (SM {current_platform.get_device_capability()})" + ) elif backend == NvFp4LinearBackend.VLLM_CUTLASS: assert cutlass_fp4_supported(), f"Cutlass is required for {backend}" elif backend == NvFp4LinearBackend.MARLIN: From 3b1dbaad4e59742ab5a5e4fd8f4ccc4f0e83c996 Mon Sep 17 00:00:00 2001 From: "Chendi.Xue" Date: Mon, 30 Mar 2026 11:47:30 -0500 Subject: [PATCH 0026/1601] [HMA]Fix corner case when hybrid page_size can not be evenly divided issue (blk_size=64,tp=4) (#37467) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: Chendi Xue Signed-off-by: Matthew Bonanni Signed-off-by: Chendi.Xue Co-authored-by: Matthew Bonanni Co-authored-by: Nicolò Lucchesi --- tests/v1/worker/test_gpu_model_runner.py | 2 + vllm/config/cache.py | 5 + vllm/model_executor/models/config.py | 149 +------------ vllm/platforms/interface.py | 197 +++++++++++++++--- vllm/platforms/xpu.py | 10 - vllm/v1/attention/backend.py | 4 + vllm/v1/attention/backends/flash_attn.py | 7 + vllm/v1/attention/backends/gdn_attn.py | 4 + vllm/v1/attention/backends/linear_attn.py | 4 + vllm/v1/attention/backends/mamba1_attn.py | 4 + vllm/v1/attention/backends/mamba2_attn.py | 4 + vllm/v1/attention/backends/short_conv_attn.py | 4 + 12 files changed, 214 insertions(+), 180 deletions(-) diff --git a/tests/v1/worker/test_gpu_model_runner.py b/tests/v1/worker/test_gpu_model_runner.py index 0fbd6605a7fe..d7695027a284 100644 --- a/tests/v1/worker/test_gpu_model_runner.py +++ b/tests/v1/worker/test_gpu_model_runner.py @@ -851,6 +851,7 @@ def test_hybrid_attention_mamba_tensor_shapes(): vllm_ctx = vllm_config.compilation_config.static_forward_context runner = GPUModelRunner(vllm_config, DEVICE) + current_platform.update_block_size_for_backend(vllm_config) kv_cache_spec = runner.get_kv_cache_spec() available_memory = 5 * GiB_bytes @@ -1306,6 +1307,7 @@ def test_mamba_cache_raises_when_max_num_seqs_exceeds_blocks(): assert fwd_context is not None runner = GPUModelRunner(vllm_config, DEVICE) + current_platform.update_block_size_for_backend(vllm_config) kv_cache_spec = runner.get_kv_cache_spec() available_memory = 5 * GiB_bytes diff --git a/vllm/config/cache.py b/vllm/config/cache.py index 1cadb43183d2..49c8868e709f 100644 --- a/vllm/config/cache.py +++ b/vllm/config/cache.py @@ -38,6 +38,8 @@ class CacheConfig: Accepts None (meaning "use default"). After construction, always int.""" user_specified_block_size: bool = field(default=False, init=False) """Whether block_size was explicitly provided. Derived automatically.""" + user_specified_mamba_block_size: bool = field(default=False, init=False) + """Whether mamba_block_size was explicitly provided. Derived automatically.""" gpu_memory_utilization: float = Field(default=0.9, gt=0, le=1) """The fraction of GPU memory to be used for the model executor, which can range from 0 to 1. For example, a value of 0.5 would imply 50% GPU memory @@ -182,6 +184,7 @@ def compute_hash(self) -> str: "cpu_kvcache_space_bytes", "mamba_page_size_padded", "user_specified_block_size", + "user_specified_mamba_block_size", "_block_size_resolved", # Post-init/derived counters "num_gpu_blocks", @@ -214,6 +217,8 @@ def _apply_block_size_default(self) -> "CacheConfig": object.__setattr__(self, "block_size", self.DEFAULT_BLOCK_SIZE) else: object.__setattr__(self, "user_specified_block_size", True) + if self.mamba_block_size is not None: + object.__setattr__(self, "user_specified_mamba_block_size", True) return self @field_validator("calculate_kv_scales", mode="after") diff --git a/vllm/model_executor/models/config.py b/vllm/model_executor/models/config.py index a5644a414aee..03b147e5c257 100644 --- a/vllm/model_executor/models/config.py +++ b/vllm/model_executor/models/config.py @@ -1,15 +1,10 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project from copy import deepcopy -from math import lcm from typing import TYPE_CHECKING from vllm.logger import init_logger -from vllm.model_executor.models import ModelRegistry -from vllm.utils.math_utils import cdiv, round_up -from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE -from vllm.v1.attention.backends.registry import AttentionBackendEnum -from vllm.v1.kv_cache_interface import FullAttentionSpec, MambaSpec, MLAAttentionSpec +from vllm.utils.math_utils import round_up if TYPE_CHECKING: from vllm.config import ModelConfig, VllmConfig @@ -104,11 +99,11 @@ class HybridAttentionMambaModelConfig(VerifyAndUpdateConfig): @classmethod def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: """ - Ensure that page size of attention layers is greater than or - equal to the mamba layers. If not, automatically set the attention - block size to ensure that it is. If the attention page size is - strictly greater than the mamba page size, we pad the mamba page size - to make them equal. + Perform early validation and setup for hybrid attention/mamba models. + + Block size alignment with mamba page sizes is handled later by + Platform.update_block_size_for_backend(), which runs after model + layers are constructed and the attention backend is known. Args: vllm_config: vLLM Config @@ -118,6 +113,7 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: # Disable calculate_kv_scales for hybrid models: uninitialized # recurrent state corrupts scales during the calibration pass. # See issue: https://github.com/vllm-project/vllm/issues/37554 + if cache_config.calculate_kv_scales: logger.warning( "Disabling calculate_kv_scales for hybrid model '%s'. " @@ -129,140 +125,9 @@ def verify_and_update_config(cls, vllm_config: "VllmConfig") -> None: ) cache_config.calculate_kv_scales = False - # Save the user input before it gets modified by MambaModelConfig - mamba_block_size = cache_config.mamba_block_size # Enable FULL_AND_PIECEWISE by default MambaModelConfig.verify_and_update_config(vllm_config) - attention_config = vllm_config.attention_config - cache_config = vllm_config.cache_config - model_config = vllm_config.model_config - parallel_config = vllm_config.parallel_config - - if cache_config.cache_dtype == "auto": - kv_cache_dtype = model_config.dtype - else: - kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] - - # get attention page size (for 1 token) - # Attention backend constraints: - # - FlashAttention (FA) requires block size to be multiple of 16 - # - MLA (Multi-head Latent Attention) requires larger alignment: - # * CUTLASS_MLA backend: kernel_block_size 128 alignment - # * Other MLA backends: kernel_block_size 64 alignment - if model_config.use_mla: - use_cutlass_mla = ( - attention_config.backend == AttentionBackendEnum.CUTLASS_MLA - ) - kernel_block_alignment_size = 128 if use_cutlass_mla else 64 - attn_page_size_1_token = MLAAttentionSpec( - block_size=1, - num_kv_heads=model_config.get_num_kv_heads(parallel_config), - head_size=model_config.get_head_size(), - dtype=kv_cache_dtype, - ).page_size_bytes - else: - kernel_block_alignment_size = 16 - attn_page_size_1_token = FullAttentionSpec( - block_size=1, - num_kv_heads=model_config.get_num_kv_heads(parallel_config), - head_size=model_config.get_head_size(), - dtype=kv_cache_dtype, - ).page_size_bytes - - model_cls, _ = ModelRegistry.resolve_model_cls( - model_config.architecture, - model_config=model_config, - ) - - # get mamba page size - mamba_page_size = MambaSpec( - shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), - dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), - block_size=-1, # block_size doesn't matter for mamba page size - ).page_size_bytes - - # Model may be marked as is_hybrid - # but mamba is skipped via config, - # return directly - if mamba_page_size == 0: - return - - if cache_config.mamba_cache_mode == "all": - # With prefix caching, select attention block size to - # optimize for mamba kernel performance - - # Mamba2 SSD kernel uses a chunk_size, e.g. 256 - # Align the block to the kernel: use lowest multiple of chunk_size - # of attention tokens that would fit mamba_page_size: - # e.g. for mamba page size = 788kB - # attn_1_token = 2kB -> fits ~394 tokens - # then round up to a multiple of 256 -> 512 tokens - # End result: - # attn_block_size = 512 - # mamba_block_size = 512 (aligned to a multiple of chunk_size) - # TODO(tdoublep): this constraint can be relaxed fairly - # easily by changing the way we layout chunks in the - # mamba2 kernels. - - base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size() - attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) - chunk_size = lcm(base_chunk_size, kernel_block_alignment_size) - attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size) - cache_config.mamba_block_size = attn_block_size - else: - # Without prefix caching, select minimum valid attention block size - # to minimize mamba state padding - - # Calculate minimum attention block size that satisfies both: - # 1. Backend alignment requirements (kernel_block_alignment_size) - # 2. Mamba page size compatibility (attn_page_size >= mamba_page_size) - attn_block_size = kernel_block_alignment_size * cdiv( - mamba_page_size, kernel_block_alignment_size * attn_page_size_1_token - ) - - # override attention block size if it is too small, - # even if the user has explicitly set it - if cache_config.block_size < attn_block_size: - cache_config.block_size = attn_block_size - logger.info( - "Setting attention block size to %d tokens " - "to ensure that attention page size is >= mamba page size.", - attn_block_size, - ) - - # By default, mamba block size will be set to max_model_len. - # When enabling prefix caching and using align mamba cache - # mode, we align mamba block size to the block size as the - # basic granularity for prefix caching. - if cache_config.mamba_cache_mode == "align": - cache_config.mamba_block_size = cache_config.block_size - - # compute new attention page size - attn_page_size = cache_config.block_size * attn_page_size_1_token - - assert attn_page_size >= mamba_page_size - - if attn_page_size == mamba_page_size: - # don't need to pad mamba page size - return - - # pad mamba page size to exactly match attention - if ( - cache_config.mamba_page_size_padded is None - or cache_config.mamba_page_size_padded != attn_page_size - ): - cache_config.mamba_page_size_padded = attn_page_size - mamba_padding_pct = ( - 100 * (attn_page_size - mamba_page_size) / mamba_page_size - ) - logger.info( - "Padding mamba page size by %.2f%% to ensure " - "that mamba page size and attention page size are " - "exactly equal.", - mamba_padding_pct, - ) - class JambaForSequenceClassificationConfig(VerifyAndUpdateConfig): @staticmethod diff --git a/vllm/platforms/interface.py b/vllm/platforms/interface.py index 281e91999ef3..fae37442ec57 100644 --- a/vllm/platforms/interface.py +++ b/vllm/platforms/interface.py @@ -21,6 +21,7 @@ from vllm.pooling_params import PoolingParams from vllm.sampling_params import SamplingParams from vllm.utils.argparse_utils import FlexibleArgumentParser + from vllm.v1.attention.backend import AttentionBackend from vllm.v1.attention.selector import AttentionSelectorConfig else: FlexibleArgumentParser = object @@ -423,55 +424,195 @@ def check_and_update_config(cls, vllm_config: "VllmConfig") -> None: """ pass + @classmethod + def _find_non_ssm_backend( + cls, vllm_config: "VllmConfig" + ) -> "type[AttentionBackend] | None": + """Find the first non-SSM attention backend from model layers.""" + from vllm.config.vllm import get_layers_from_vllm_config + from vllm.model_executor.layers.attention_layer_base import ( + AttentionLayerBase, + ) + + attn_layers = get_layers_from_vllm_config( + vllm_config, + AttentionLayerBase, # type: ignore[type-abstract] + ) + for layer in attn_layers.values(): + b = layer.get_attn_backend() + if not b.is_ssm(): + return b + return None + @classmethod def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: """ Ensure block_size is compatible with the attention backend. + For hybrid models, also aligns block_size with mamba page sizes. """ from vllm.config.cache import CacheConfig + from vllm.config.vllm import set_current_vllm_config cache_config = vllm_config.cache_config - if cache_config.user_specified_block_size: - # User specified --block-size; keep it. - return - model_config = vllm_config.model_config + # model_config may be None during testing. - # Skip hybrid models — their block_size is managed by - # HybridAttentionMambaModelConfig. - if model_config is None or model_config.is_hybrid: - cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE + if not model_config: return - from vllm.config.vllm import ( - get_layers_from_vllm_config, - set_current_vllm_config, - ) - from vllm.model_executor.layers.attention_layer_base import ( - AttentionLayerBase, + backend_cls = cls._find_non_ssm_backend(vllm_config) + if backend_cls is None: + return + + # Phase 1: Pick block size from backend (skip if user set --block-size) + if not cache_config.user_specified_block_size: + with set_current_vllm_config(vllm_config): + preferred = backend_cls.get_preferred_block_size( + CacheConfig.DEFAULT_BLOCK_SIZE + ) + if preferred != CacheConfig.DEFAULT_BLOCK_SIZE: + logger.info( + "Setting kv cache block size to %d for %s backend.", + preferred, + backend_cls.get_name(), + ) + cache_config.block_size = preferred + + # Phase 2: Align block/mamba sizes for hybrid models + # (may override user settings). + if model_config.is_hybrid: + cls._align_hybrid_block_size(vllm_config, backend_cls) + + @classmethod + def _align_hybrid_block_size( + cls, + vllm_config: "VllmConfig", + backend_cls: "type[AttentionBackend]", + ) -> None: + """ + For hybrid attention/mamba models, ensure that the attention page + size is >= the mamba page size, and pad the mamba page size to match. + """ + from math import lcm + + from vllm.config.vllm import set_current_vllm_config + from vllm.model_executor.models import ModelRegistry + from vllm.utils.math_utils import cdiv + from vllm.utils.torch_utils import STR_DTYPE_TO_TORCH_DTYPE + from vllm.v1.attention.backend import MultipleOf + from vllm.v1.kv_cache_interface import ( + FullAttentionSpec, + MambaSpec, + MLAAttentionSpec, ) - attn_layers = get_layers_from_vllm_config( - vllm_config, - AttentionLayerBase, # type: ignore[type-abstract] + cache_config = vllm_config.cache_config + model_config = vllm_config.model_config + parallel_config = vllm_config.parallel_config + + if cache_config.cache_dtype == "auto": + kv_cache_dtype = model_config.dtype + else: + kv_cache_dtype = STR_DTYPE_TO_TORCH_DTYPE[cache_config.cache_dtype] + + # Compute attention page size for 1 token + if model_config.use_mla: + attn_page_size_1_token = MLAAttentionSpec( + block_size=1, + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + dtype=kv_cache_dtype, + ).page_size_bytes + else: + attn_page_size_1_token = FullAttentionSpec( + block_size=1, + num_kv_heads=model_config.get_num_kv_heads(parallel_config), + head_size=model_config.get_head_size(), + dtype=kv_cache_dtype, + ).page_size_bytes + + # Compute mamba page size + model_cls, _ = ModelRegistry.resolve_model_cls( + model_config.architecture, + model_config=model_config, ) - if not attn_layers: - cache_config.block_size = CacheConfig.DEFAULT_BLOCK_SIZE + mamba_page_size = MambaSpec( + shapes=model_cls.get_mamba_state_shape_from_config(vllm_config), + dtypes=model_cls.get_mamba_state_dtype_from_config(vllm_config), + block_size=-1, + ).page_size_bytes + + if mamba_page_size == 0: return - first_layer = next(iter(attn_layers.values())) - backend_cls = first_layer.get_attn_backend() + # mamba_block_size here should either be user specified value or None + mamba_block_size = ( + cache_config.mamba_block_size + if cache_config.user_specified_mamba_block_size + else None + ) + + # Get kernel block alignment from the backend's supported sizes with set_current_vllm_config(vllm_config): - preferred = backend_cls.get_preferred_block_size( - CacheConfig.DEFAULT_BLOCK_SIZE + kernel_block_alignment_size = max( + min( + s.base if isinstance(s, MultipleOf) else s + for s in backend_cls.get_supported_kernel_block_sizes() + ), + cache_config.block_size, + ) + + if cache_config.mamba_cache_mode == "all": + # With prefix caching, align to mamba chunk size for kernel perf + # TODO(tdoublep): this constraint can be relaxed fairly + # easily by changing the way we layout chunks in the + # mamba2 kernels. + base_chunk_size = mamba_block_size or model_config.get_mamba_chunk_size() + assert base_chunk_size is not None + attn_tokens_per_mamba_state = cdiv(mamba_page_size, attn_page_size_1_token) + chunk_size = lcm(base_chunk_size, kernel_block_alignment_size) + attn_block_size = chunk_size * cdiv(attn_tokens_per_mamba_state, chunk_size) + cache_config.mamba_block_size = attn_block_size + else: + # Without prefix caching, use minimum block size that satisfies + # both backend alignment and mamba page size compatibility + attn_block_size = kernel_block_alignment_size * cdiv( + mamba_page_size, + kernel_block_alignment_size * attn_page_size_1_token, + ) + + if cache_config.block_size < attn_block_size: + cache_config.block_size = attn_block_size + logger.info( + "Setting attention block size to %d tokens " + "to ensure that attention page size is >= mamba page size.", + attn_block_size, + ) + + if cache_config.mamba_cache_mode == "align": + cache_config.mamba_block_size = cache_config.block_size + + # Pad mamba page size to exactly match attention page size + attn_page_size = cache_config.block_size * attn_page_size_1_token + assert attn_page_size >= mamba_page_size + + if attn_page_size == mamba_page_size: + return + + if ( + cache_config.mamba_page_size_padded is None + or cache_config.mamba_page_size_padded != attn_page_size + ): + cache_config.mamba_page_size_padded = attn_page_size + mamba_padding_pct = ( + 100 * (attn_page_size - mamba_page_size) / mamba_page_size ) - if preferred != CacheConfig.DEFAULT_BLOCK_SIZE: logger.info( - "Setting kv cache block size to %d for %s backend.", - preferred, - backend_cls.get_name(), + "Padding mamba page size by %.2f%% to ensure " + "that mamba page size and attention page size are " + "exactly equal.", + mamba_padding_pct, ) - cache_config.block_size = preferred @classmethod def verify_model_arch(cls, model_arch: str) -> None: diff --git a/vllm/platforms/xpu.py b/vllm/platforms/xpu.py index b8cab5f45dcd..bf96b94af377 100644 --- a/vllm/platforms/xpu.py +++ b/vllm/platforms/xpu.py @@ -160,11 +160,7 @@ def get_static_graph_wrapper_cls(cls) -> str: @classmethod def check_and_update_config(cls, vllm_config: VllmConfig) -> None: - cache_config = vllm_config.cache_config parallel_config = vllm_config.parallel_config - # in V1(or with chunked prefill) block_size is 64 - if cache_config and not cache_config.user_specified_block_size: - cache_config.block_size = 64 # lazy import to avoid circular import from vllm.config import CUDAGraphMode @@ -221,12 +217,6 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None: # ref. https://openucx.readthedocs.io/en/master/faq.html os.environ["UCX_MEMTYPE_CACHE"] = "n" - @classmethod - def update_block_size_for_backend(cls, vllm_config: "VllmConfig") -> None: - # TODO: XPU still sets block_size in check_and_update_config. - # Move that logic here so block_size is chosen by the backend. - pass - @classmethod def support_hybrid_kv_cache(cls) -> bool: return True diff --git a/vllm/v1/attention/backend.py b/vllm/v1/attention/backend.py index cd49ea30e6f4..9001b23f3d54 100644 --- a/vllm/v1/attention/backend.py +++ b/vllm/v1/attention/backend.py @@ -311,6 +311,10 @@ def validate_configuration( def get_required_kv_cache_layout(cls) -> "KVCacheLayoutType | None": return None + @classmethod + def is_ssm(cls) -> bool: + return False + class AttentionMetadata: pass diff --git a/vllm/v1/attention/backends/flash_attn.py b/vllm/v1/attention/backends/flash_attn.py index 245995be2642..5e63fa5921d3 100755 --- a/vllm/v1/attention/backends/flash_attn.py +++ b/vllm/v1/attention/backends/flash_attn.py @@ -43,6 +43,7 @@ from vllm.config.cache import CacheDType from vllm.distributed.parallel_state import get_dcp_group from vllm.logger import init_logger +from vllm.platforms import current_platform from vllm.platforms.interface import DeviceCapability from vllm.utils.math_utils import cdiv, round_up from vllm.v1.attention.backend import ( @@ -90,6 +91,12 @@ def get_supported_kernel_block_sizes() -> list[int | MultipleOf]: forward_includes_kv_cache_update: bool = False + @classmethod + def get_preferred_block_size(cls, default_block_size: int) -> int: + if current_platform.is_xpu(): + return max(default_block_size, 64) + return super().get_preferred_block_size(default_block_size) + @staticmethod def get_name() -> str: return "FLASH_ATTN" diff --git a/vllm/v1/attention/backends/gdn_attn.py b/vllm/v1/attention/backends/gdn_attn.py index 574cc87e7582..f65d9a4b3891 100644 --- a/vllm/v1/attention/backends/gdn_attn.py +++ b/vllm/v1/attention/backends/gdn_attn.py @@ -31,6 +31,10 @@ def get_name() -> str: def get_builder_cls() -> type["GDNAttentionMetadataBuilder"]: return GDNAttentionMetadataBuilder + @classmethod + def is_ssm(cls) -> bool: + return True + @dataclass class GDNAttentionMetadata: diff --git a/vllm/v1/attention/backends/linear_attn.py b/vllm/v1/attention/backends/linear_attn.py index fe27e7a389ac..b2ca151986cc 100644 --- a/vllm/v1/attention/backends/linear_attn.py +++ b/vllm/v1/attention/backends/linear_attn.py @@ -27,6 +27,10 @@ def get_name() -> str: def get_builder_cls() -> type["LinearAttentionMetadataBuilder"]: return LinearAttentionMetadataBuilder + @classmethod + def is_ssm(cls) -> bool: + return True + @dataclass class LinearAttentionMetadata: diff --git a/vllm/v1/attention/backends/mamba1_attn.py b/vllm/v1/attention/backends/mamba1_attn.py index 8903406200ca..925fceb024f6 100644 --- a/vllm/v1/attention/backends/mamba1_attn.py +++ b/vllm/v1/attention/backends/mamba1_attn.py @@ -20,6 +20,10 @@ def get_name() -> str: def get_builder_cls() -> type["Mamba1AttentionMetadataBuilder"]: return Mamba1AttentionMetadataBuilder + @classmethod + def is_ssm(cls) -> bool: + return True + @dataclass class Mamba1AttentionMetadata(BaseMambaAttentionMetadata): diff --git a/vllm/v1/attention/backends/mamba2_attn.py b/vllm/v1/attention/backends/mamba2_attn.py index 5e8abbab565e..fa7d4bd2ec51 100644 --- a/vllm/v1/attention/backends/mamba2_attn.py +++ b/vllm/v1/attention/backends/mamba2_attn.py @@ -96,6 +96,10 @@ def get_name() -> str: def get_builder_cls() -> type["Mamba2AttentionMetadataBuilder"]: return Mamba2AttentionMetadataBuilder + @classmethod + def is_ssm(cls) -> bool: + return True + @dataclass class Mamba2AttentionMetadata(BaseMambaAttentionMetadata): diff --git a/vllm/v1/attention/backends/short_conv_attn.py b/vllm/v1/attention/backends/short_conv_attn.py index c6a8e6eeaa16..9c85ec5efb30 100644 --- a/vllm/v1/attention/backends/short_conv_attn.py +++ b/vllm/v1/attention/backends/short_conv_attn.py @@ -18,6 +18,10 @@ def get_name() -> str: def get_builder_cls() -> type["ShortConvAttentionMetadataBuilder"]: return ShortConvAttentionMetadataBuilder + @classmethod + def is_ssm(cls) -> bool: + return True + @dataclass class ShortConvAttentionMetadata(BaseMambaAttentionMetadata): From 2c734ed0e06a48808522fe8f59f6b4ffe0cf0397 Mon Sep 17 00:00:00 2001 From: Matthew Bonanni Date: Mon, 30 Mar 2026 12:51:24 -0400 Subject: [PATCH 0027/1601] [Bugfix][MLA] Change default SM100 MLA prefill backend back to TRT-LLM (#38562) Signed-off-by: Matthew Bonanni --- vllm/config/attention.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/config/attention.py b/vllm/config/attention.py index 1da647a6d6ff..014bb9b22601 100644 --- a/vllm/config/attention.py +++ b/vllm/config/attention.py @@ -30,7 +30,7 @@ class AttentionConfig: use_cudnn_prefill: bool = False """Whether to use cudnn prefill.""" - use_trtllm_ragged_deepseek_prefill: bool = False + use_trtllm_ragged_deepseek_prefill: bool = True """Whether to use TRTLLM ragged deepseek prefill.""" use_trtllm_attention: bool | None = None From b5e608258e7b5e4abadf84ffee36e584d7e00b7d Mon Sep 17 00:00:00 2001 From: fangyuchu Date: Tue, 31 Mar 2026 01:16:09 +0800 Subject: [PATCH 0028/1601] [Refactor] Unify engine process monitoring in engine manager and add Ray backend support (#35862) Signed-off-by: fangyuchu Signed-off-by: Nick Hill Co-authored-by: Nick Hill --- vllm/entrypoints/cli/serve.py | 2 +- vllm/v1/engine/core_client.py | 24 +++------- vllm/v1/engine/utils.py | 87 ++++++++++++++++++++++++++++++++--- vllm/v1/utils.py | 40 +++++++++------- 4 files changed, 111 insertions(+), 42 deletions(-) diff --git a/vllm/entrypoints/cli/serve.py b/vllm/entrypoints/cli/serve.py index 1ece3e4df18c..e265a088a16c 100644 --- a/vllm/entrypoints/cli/serve.py +++ b/vllm/entrypoints/cli/serve.py @@ -225,7 +225,7 @@ def signal_handler(signum, frame): ) try: - engine_manager.join_first() + engine_manager.monitor_engine_liveness() finally: timeout = None if shutdown_requested: diff --git a/vllm/v1/engine/core_client.py b/vllm/v1/engine/core_client.py index b9a3c7545e16..1d73c12eda29 100644 --- a/vllm/v1/engine/core_client.py +++ b/vllm/v1/engine/core_client.py @@ -2,7 +2,6 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import asyncio import contextlib -import multiprocessing import queue import sys import uuid @@ -640,34 +639,20 @@ def dp_engines_running(self) -> bool: def start_engine_core_monitor(self): """Start a monitor thread for engine core processes.""" engine_manager = self.resources.engine_manager - if ( - engine_manager is None - or not hasattr(engine_manager, "processes") - or not engine_manager.processes - ): + if engine_manager is None: # No engine processes to monitor return - engine_processes = engine_manager.processes self_ref = weakref.ref(self) # Monitor engine core process liveness. If any die unexpectedly, - # logs an error, shuts down the client and invokes the failure - # callback to inform the engine. + # marks the engine as dead, and shuts down the client. def monitor_engine_cores(): - sentinels = [proc.sentinel for proc in engine_processes] - died = multiprocessing.connection.wait(sentinels) + engine_manager.monitor_engine_liveness() _self = self_ref() if not _self or not _self._finalizer.alive or _self.resources.engine_dead: return _self.resources.engine_dead = True - proc_name = next( - proc.name for proc in engine_processes if proc.sentinel == died[0] - ) - logger.error( - "Engine core proc %s died unexpectedly, shutting down client.", - proc_name, - ) _self.shutdown() # Note: For MPClient, we don't have a failure callback mechanism # like MultiprocExecutor, but we set engine_dead flag which will @@ -1634,6 +1619,9 @@ async def _scale_down_elastic_ep( parallel_config = self.vllm_config.parallel_config ip, coord_store_port = self._setup_elastic_ep_reconfig_bootstrap() + removed_dp_size = cur_data_parallel_size - new_data_parallel_size + assert isinstance(self.resources.engine_manager, CoreEngineActorManager) + self.resources.engine_manager.remove_run_refs_for_scale_down(removed_dp_size) reconfig_futures = [] for cur_dp_rank, engine in enumerate(self.core_engines): reconfig_request = ReconfigureDistributedRequest( diff --git a/vllm/v1/engine/utils.py b/vllm/v1/engine/utils.py index 90ec47edb033..0ce0ed88e414 100644 --- a/vllm/v1/engine/utils.py +++ b/vllm/v1/engine/utils.py @@ -11,7 +11,7 @@ from multiprocessing import Process, connection from multiprocessing.process import BaseProcess from multiprocessing.queues import Queue -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, cast from unittest.mock import patch import msgspec @@ -133,6 +133,8 @@ def __init__( ) self._finalizer = weakref.finalize(self, shutdown, self.processes) + self.manager_stopped = threading.Event() + self.failed_proc_name: str | None = None try: for proc, local_dp_rank in zip(self.processes, local_dp_ranks): @@ -154,12 +156,31 @@ def __init__( def shutdown(self, timeout: float | None = None) -> None: """Shutdown engine core processes with configurable timeout.""" + self.manager_stopped.set() if self._finalizer.detach() is not None: shutdown(self.processes, timeout=timeout) - def join_first(self): - """Wait for any process to exit.""" - connection.wait(proc.sentinel for proc in self.processes) + def monitor_engine_liveness(self) -> None: + """Monitor engine core process liveness.""" + + sentinel_to_proc = {proc.sentinel: proc for proc in self.processes} + sentinels = set(sentinel_to_proc.keys()) + + while sentinels and not self.manager_stopped.is_set(): + died_sentinels = connection.wait(sentinels, timeout=1) + + for sentinel in died_sentinels: + proc = sentinel_to_proc.pop(cast(int, sentinel)) + exitcode = proc.exitcode + if exitcode != 0 and not self.manager_stopped.is_set(): + self.failed_proc_name = proc.name + if died_sentinels: + # Any engine exit currently triggers a shutdown. Future + # work (e.g., Elastic and fault-tolerant EP) will add finer-grained + # handling for different exit scenarios. + break + + self.shutdown() def sentinels(self) -> list: return [proc.sentinel for proc in self.processes] @@ -298,6 +319,8 @@ def __init__( self.log_stats = log_stats local_engine_count = vllm_config.parallel_config.data_parallel_size_local world_size = vllm_config.parallel_config.world_size + self.manager_stopped = threading.Event() + self.failed_proc_name: str | None = None if ray.is_initialized(): logger.info("Ray is already initialized. Skipping Ray initialization.") @@ -395,8 +418,11 @@ def __init__( ray.get(refs) self.run_refs = [] + self.actor_run_ref_dict = dict() for actor in self.local_engine_actors + self.remote_engine_actors: - self.run_refs.append(actor.run.remote()) + ref = actor.run.remote() + self.run_refs.append(ref) + self.actor_run_ref_dict[actor] = ref @staticmethod def create_dp_placement_groups( @@ -776,7 +802,9 @@ def scale_up_elastic_ep( ) + self.remote_engine_actors[-(len(placement_groups) - new_local_engines) :] for actor in actors: - self.run_refs.append(actor.run.remote()) + ref = actor.run.remote() + self.run_refs.append(ref) + self.actor_run_ref_dict[actor] = ref cur_vllm_config.parallel_config.data_parallel_size = new_data_parallel_size # Update old_vllm_config with new data_parallel_size_local if any new @@ -805,12 +833,59 @@ def scale_down_elastic_ep( self.remote_engine_actors.pop() ray.util.remove_placement_group(pg) + def remove_run_refs_for_scale_down(self, removed_dp_size: int) -> None: + if removed_dp_size <= 0: + return + flags = self.placement_group_is_local[-removed_dp_size:] + li = len(self.local_engine_actors) - 1 + ri = len(self.remote_engine_actors) - 1 + for is_local in reversed(flags): + if is_local: + actor = self.local_engine_actors[li] + li -= 1 + else: + actor = self.remote_engine_actors[ri] + ri -= 1 + ref = self.actor_run_ref_dict.pop(actor) + self.run_refs.remove(ref) + def get_run_refs(self): return self.run_refs + def monitor_engine_liveness(self) -> None: + import ray + + while not self.manager_stopped.is_set(): + actor_run_refs = list(self.get_run_refs()) + if not actor_run_refs: + logger.info( + "There are no actors to monitor currently. " + "The monitoring function is about to terminate." + ) + break + actor_done_refs, _ = ray.wait(actor_run_refs, timeout=5) + unexpected_failure = False + for actor_ref in actor_done_refs: + if self.manager_stopped.is_set(): + break + if actor_ref not in self.get_run_refs(): + # The run refs may have been updated by elastic scale-down. + continue + try: + ray.get(actor_ref) + except ray.exceptions.RayActorError: + self.failed_proc_name = f"Actor {actor_ref}" + unexpected_failure = True + + if unexpected_failure: + break + + self.shutdown() + def shutdown(self, timeout: float | None = None) -> None: import ray + self.manager_stopped.set() for actor in self.local_engine_actors + self.remote_engine_actors: ray.kill(actor) for pg in self.created_placement_groups: diff --git a/vllm/v1/utils.py b/vllm/v1/utils.py index 1aa36b1a5f81..eb81a3c88fb7 100644 --- a/vllm/v1/utils.py +++ b/vllm/v1/utils.py @@ -3,6 +3,7 @@ import argparse import contextlib import multiprocessing +import threading import time import weakref from collections.abc import Callable, Sequence @@ -269,8 +270,6 @@ def wait_for_completion_or_failure( coordinator: The coordinator for data parallel. """ - from vllm.v1.engine.utils import CoreEngineActorManager, CoreEngineProcManager - try: logger.info("Waiting for API servers to complete ...") # Create a mapping of sentinels to their corresponding processes @@ -282,33 +281,40 @@ def wait_for_completion_or_failure( if coordinator: sentinel_to_proc[coordinator.proc.sentinel] = coordinator.proc - actor_run_refs = [] - if isinstance(engine_manager, CoreEngineProcManager): - for proc in engine_manager.processes: - sentinel_to_proc[proc.sentinel] = proc - elif isinstance(engine_manager, CoreEngineActorManager): - actor_run_refs = engine_manager.get_run_refs() + if engine_manager: + core_shutdown_recv, core_shutdown_send = connection.Pipe(duplex=False) + + def monitor_engines(): + try: + engine_manager.monitor_engine_liveness() + finally: + core_shutdown_send.close() + core_shutdown_recv.close() + + # start monitor for engine liveness + threading.Thread(target=monitor_engines, daemon=True).start() + sentinel_to_proc[core_shutdown_recv] = None # type: ignore[assignment] # Check if any process terminates - while sentinel_to_proc or actor_run_refs: - # Wait for any process to terminate - ready_sentinels: list[Any] = connection.wait(sentinel_to_proc, timeout=5) + while sentinel_to_proc: + # Wait for any process to terminate (or engine shutdown signal) + ready_sentinels: list[Any] = connection.wait(sentinel_to_proc) # Process any terminated processes for sentinel in ready_sentinels: proc = sentinel_to_proc.pop(sentinel) # Check if process exited with error - if proc.exitcode != 0: + if proc is not None and proc.exitcode != 0: raise RuntimeError( f"Process {proc.name} (PID: {proc.pid}) " f"died with exit code {proc.exitcode}" ) - - if actor_run_refs: - import ray - - _, actor_run_refs = ray.wait(actor_run_refs, timeout=5) + if engine_manager and engine_manager.failed_proc_name is not None: + raise RuntimeError( + f"Engine core process {engine_manager.failed_proc_name} " + "died unexpectedly." + ) except KeyboardInterrupt: logger.info("Received KeyboardInterrupt, shutting down API servers...") From ab1a6a43fa9500697dd01e73aa372c8777cd7a5b Mon Sep 17 00:00:00 2001 From: mikaylagawarecki Date: Mon, 30 Mar 2026 14:20:13 -0400 Subject: [PATCH 0029/1601] [3/n] Migrate cutlass/scaled_mm_entry.cu torch stable ABI (#37221) Signed-off-by: Mikayla Gawarecki --- CMakeLists.txt | 418 ++++++++-------- csrc/cutlass_extensions/common.hpp | 12 +- .../epilogue/scaled_mm_epilogues_c3x.hpp | 52 +- .../epilogue/scaled_mm_epilogues_c2x.hpp | 36 +- csrc/libtorch_stable/ops.h | 57 +++ .../w8a8/cutlass/c3x/cutlass_gemm_caller.cuh | 22 +- .../w8a8/cutlass/c3x/scaled_mm.cuh | 0 .../cutlass/c3x/scaled_mm_azp_sm90_int8.cu | 13 +- .../c3x/scaled_mm_blockwise_sm100_fp8.cu | 22 + ...scaled_mm_blockwise_sm100_fp8_dispatch.cuh | 20 +- .../c3x/scaled_mm_blockwise_sm120_fp8.cu | 22 + ...scaled_mm_blockwise_sm120_fp8_dispatch.cuh | 20 +- .../c3x/scaled_mm_blockwise_sm90_fp8.cu | 23 + .../scaled_mm_blockwise_sm90_fp8_dispatch.cuh | 22 +- .../w8a8/cutlass/c3x/scaled_mm_helper.hpp | 37 +- .../w8a8/cutlass/c3x/scaled_mm_kernels.hpp | 52 ++ .../w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu | 24 + .../c3x/scaled_mm_sm100_fp8_dispatch.cuh | 43 +- .../w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu | 25 + .../c3x/scaled_mm_sm120_fp8_dispatch.cuh | 30 +- .../w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu | 24 + .../c3x/scaled_mm_sm90_fp8_dispatch.cuh | 43 +- .../w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu | 25 + .../c3x/scaled_mm_sm90_int8_dispatch.cuh | 26 +- .../w8a8/cutlass/moe/get_group_starts.cuh | 47 +- .../w8a8/cutlass/moe/grouped_mm_c3x.cuh | 47 +- .../w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu | 79 +-- .../w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu | 79 +-- .../quantization/w8a8/cutlass/moe/moe_data.cu | 131 ++--- .../w8a8/cutlass/scaled_mm_c2x.cu | 220 +++++++++ .../w8a8/cutlass/scaled_mm_c2x.cuh | 29 +- .../cutlass/scaled_mm_c2x_sm75_dispatch.cuh | 12 +- .../cutlass/scaled_mm_c2x_sm80_dispatch.cuh | 12 +- .../scaled_mm_c2x_sm89_fp8_dispatch.cuh | 62 ++- .../scaled_mm_c2x_sm89_int8_dispatch.cuh | 54 ++- .../w8a8/cutlass/scaled_mm_c3x_sm100.cu | 11 +- .../w8a8/cutlass/scaled_mm_c3x_sm120.cu | 11 +- .../w8a8/cutlass/scaled_mm_c3x_sm90.cu | 38 ++ .../w8a8/cutlass/scaled_mm_entry.cu | 451 ++++++++++++++++++ csrc/libtorch_stable/torch_bindings.cpp | 97 ++++ csrc/libtorch_stable/torch_utils.h | 6 + csrc/ops.h | 45 -- .../c3x/scaled_mm_blockwise_sm100_fp8.cu | 23 - .../c3x/scaled_mm_blockwise_sm120_fp8.cu | 23 - .../c3x/scaled_mm_blockwise_sm90_fp8.cu | 24 - .../w8a8/cutlass/c3x/scaled_mm_kernels.hpp | 56 --- .../w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu | 23 - .../w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu | 24 - .../w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu | 23 - .../w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu | 24 - .../w8a8/cutlass/scaled_mm_c2x.cu | 199 -------- .../w8a8/cutlass/scaled_mm_c3x_sm90.cu | 36 -- .../w8a8/cutlass/scaled_mm_entry.cu | 420 ---------------- csrc/torch_bindings.cpp | 84 ---- 54 files changed, 1845 insertions(+), 1613 deletions(-) rename csrc/{ => libtorch_stable}/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp (91%) rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh (84%) rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/c3x/scaled_mm.cuh (100%) rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu (51%) create mode 100644 csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh (93%) create mode 100644 csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh (91%) create mode 100644 csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh (87%) rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp (55%) create mode 100644 csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp create mode 100644 csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh (90%) create mode 100644 csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh (88%) create mode 100644 csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8_dispatch.cuh (91%) create mode 100644 csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8_dispatch.cuh (86%) rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/moe/get_group_starts.cuh (63%) rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/moe/grouped_mm_c3x.cuh (78%) rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu (59%) rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu (70%) rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/moe/moe_data.cu (70%) create mode 100644 csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/scaled_mm_c2x.cuh (87%) rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/scaled_mm_c2x_sm75_dispatch.cuh (92%) rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/scaled_mm_c2x_sm80_dispatch.cuh (93%) rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_fp8_dispatch.cuh (86%) rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/scaled_mm_c2x_sm89_int8_dispatch.cuh (86%) rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu (56%) rename csrc/{ => libtorch_stable}/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu (56%) create mode 100644 csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu create mode 100644 csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu delete mode 100644 csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu delete mode 100644 csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu delete mode 100644 csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu delete mode 100644 csrc/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp delete mode 100644 csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu delete mode 100644 csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu delete mode 100644 csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu delete mode 100644 csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu delete mode 100644 csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu delete mode 100644 csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu delete mode 100644 csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index ff57473cf163..fc50b896d2fc 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -340,7 +340,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_EXT_SRC "csrc/quantization/awq/gemm_kernels.cu" - "csrc/quantization/w8a8/cutlass/scaled_mm_entry.cu" "csrc/quantization/fp4/nvfp4_quant_entry.cu" "csrc/quantization/fp4/nvfp4_scaled_mm_entry.cu" "csrc/cutlass_extensions/common.cpp") @@ -490,132 +489,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") " in CUDA target architectures") endif() - - set(SCALED_MM_3X_ARCHS) - # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require - # CUDA 12.0 or later - cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) - set(SRCS - "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1") - # Let scaled_mm_c2x know it doesn't need to build these arches - list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") - message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) - message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is " - "not >= 12.0, we recommend upgrading to CUDA 12.0 or " - "later if you intend on running FP8 quantized models on " - "Hopper.") - else() - message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found " - "in CUDA target architectures") - endif() - endif() - - - # The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require - # CUDA 12.8 or later - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - set(SRCS - "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu" - ) - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1") - # Let scaled_mm_c2x know it doesn't need to build these arches - list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") - message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is " - "not >= 12.8, we recommend upgrading to CUDA 12.8 or " - "later if you intend on running FP8 quantized models on " - "Blackwell.") - else() - message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found " - "in CUDA target architectures") - endif() - endif() - - - # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x) - # require CUDA 12.8 or later - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - set(SRCS - "csrc/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu" - "csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu" - ) - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1") - # Let scaled_mm_c2x know it doesn't need to build these arches - list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") - message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is " - "not >= 12.8, we recommend upgrading to CUDA 12.8 or " - "later if you intend on running FP8 quantized models on " - "Blackwell.") - else() - message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found " - "in CUDA target architectures") - endif() - endif() - - # - # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) - # kernels for the remaining archs that are not already built for 3x. - # (Build 8.9 for FP8) - cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS - "7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}") - # subtract out the archs that are already built for 3x - list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) - if (SCALED_MM_2X_ARCHS) - set(SRCS "csrc/quantization/w8a8/cutlass/scaled_mm_c2x.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_2X_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1") - message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}") - else() - if (SCALED_MM_3X_ARCHS) - message(STATUS "Not building scaled_mm_c2x as all archs are already built" - " for and covered by scaled_mm_c3x") - else() - message(STATUS "Not building scaled_mm_c2x as no compatible archs found " - "in CUDA target architectures") - endif() - endif() - # The nvfp4_scaled_mm_sm120 kernels for Blackwell SM12x require # CUDA 12.8 or later if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) @@ -693,55 +566,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") set(MLA_ARCHS) endif() - # CUTLASS MoE kernels - - # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works - # on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled - # if it's possible to compile MoE kernels that use its output. - cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1") - message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) - message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " - "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " - "if you intend on running FP8 quantized MoE models on Hopper.") - else() - message(STATUS "Not building grouped_mm_c3x as no compatible archs found " - "in CUDA target architectures.") - endif() - endif() - - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - set(SRCS "csrc/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${SCALED_MM_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") - message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) - message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " - "not >= 12.8, we recommend upgrading to CUDA 12.8 or later " - "if you intend on running FP8 quantized MoE models on Blackwell.") - else() - message(STATUS "Not building grouped_mm_c3x as no compatible archs found " - "in CUDA target architectures.") - endif() - endif() - # Expert-specialization MXFP8 blockscaled grouped kernels (SM100+). if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) cuda_archs_loose_intersection(ES_MXFP8_GROUPED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") @@ -787,36 +611,6 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") "in CUDA target architectures.") endif() - # moe_data.cu is used by all CUTLASS MoE kernels. - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") - endif() - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) - set(SRCS "csrc/quantization/w8a8/cutlass/moe/moe_data.cu") - set_gencode_flags_for_srcs( - SRCS "${SRCS}" - CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}") - list(APPEND VLLM_EXT_SRC "${SRCS}") - message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}") - else() - if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) - message(STATUS "Not building moe_data as CUDA Compiler version is " - "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " - "if you intend on running FP8 quantized MoE models on Hopper or Blackwell.") - else() - message(STATUS "Not building moe_data as no compatible archs found " - "in CUDA target architectures.") - endif() - endif() - - if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") - else() - cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") - endif() - # # Machete kernels @@ -964,7 +758,9 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # _C_stable_libtorch extension (ops registered via STABLE_TORCH_LIBRARY) # set(VLLM_STABLE_EXT_SRC - "csrc/libtorch_stable/torch_bindings.cpp") + "csrc/libtorch_stable/torch_bindings.cpp" + "csrc/cutlass_extensions/common.cpp" + "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_entry.cu") if(VLLM_GPU_LANG STREQUAL "CUDA") list(APPEND VLLM_STABLE_EXT_SRC @@ -979,6 +775,209 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") CUDA_ARCHS "${CUDA_ARCHS}") endif() + # + # CUTLASS scaled_mm kernels (moved from _C to _C_stable_libtorch) + # + set(SCALED_MM_3X_ARCHS) + # The cutlass_scaled_mm kernels for Hopper (c3x, i.e. CUTLASS 3.x) require + # CUDA 12.0 or later + cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a;" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) + set(SRCS + "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm90.cu" + "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_fp8.cu" + "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm90_int8.cu" + "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu" + "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM90=1") + # Let scaled_mm_c2x know it doesn't need to build these arches + list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") + message(STATUS "Building scaled_mm_c3x_sm90 for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.0 AND SCALED_MM_ARCHS) + message(STATUS "Not building scaled_mm_c3x_sm90 as CUDA Compiler version is " + "not >= 12.0, we recommend upgrading to CUDA 12.0 or " + "later if you intend on running FP8 quantized models on " + "Hopper.") + else() + message(STATUS "Not building scaled_mm_c3x_sm90 as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + + + # The cutlass_scaled_mm kernels for Blackwell SM12x (c3x, i.e. CUTLASS 3.x) require + # CUDA 12.8 or later + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "12.0a;12.1a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS + "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm120.cu" + "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu" + "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu" + ) + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM120=1") + # Let scaled_mm_c2x know it doesn't need to build these arches + list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") + message(STATUS "Building scaled_mm_c3x_sm120 for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building scaled_mm_c3x_sm120 as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or " + "later if you intend on running FP8 quantized models on " + "Blackwell.") + else() + message(STATUS "Not building scaled_mm_c3x_120 as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + + + # The cutlass_scaled_mm kernels for Blackwell SM100 (c3x, i.e. CUTLASS 3.x) + # require CUDA 12.8 or later + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS + "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c3x_sm100.cu" + "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu" + "csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu" + ) + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_SM100=1") + # Let scaled_mm_c2x know it doesn't need to build these arches + list(APPEND SCALED_MM_3X_ARCHS "${SCALED_MM_ARCHS}") + message(STATUS "Building scaled_mm_c3x_sm100 for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building scaled_mm_c3x_sm100 as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or " + "later if you intend on running FP8 quantized models on " + "Blackwell.") + else() + message(STATUS "Not building scaled_mm_c3x_100 as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + + # + # For the cutlass_scaled_mm kernels we want to build the c2x (CUTLASS 2.x) + # kernels for the remaining archs that are not already built for 3x. + # (Build 8.9 for FP8) + cuda_archs_loose_intersection(SCALED_MM_2X_ARCHS + "7.5;8.0;8.7;8.9+PTX" "${CUDA_ARCHS}") + # subtract out the archs that are already built for 3x + list(REMOVE_ITEM SCALED_MM_2X_ARCHS ${SCALED_MM_3X_ARCHS}) + if (SCALED_MM_2X_ARCHS) + set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/scaled_mm_c2x.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_2X_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_SCALED_MM_C2X=1") + message(STATUS "Building scaled_mm_c2x for archs: ${SCALED_MM_2X_ARCHS}") + else() + if (SCALED_MM_3X_ARCHS) + message(STATUS "Not building scaled_mm_c2x as all archs are already built" + " for and covered by scaled_mm_c3x") + else() + message(STATUS "Not building scaled_mm_c2x as no compatible archs found " + "in CUDA target architectures") + endif() + endif() + + # + # CUTLASS MoE kernels (moved from _C to _C_stable_libtorch) + # + + # The MoE kernel cutlass_moe_mm requires CUDA 12.3 or later (and ONLY works + # on Hopper). get_cutlass_(batched_)moe_mm_data should only be compiled + # if it's possible to compile MoE kernels that use its output. + cuda_archs_loose_intersection(SCALED_MM_ARCHS "9.0a" "${CUDA_ARCHS}") + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) + set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm90.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM90=1") + message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND SCALED_MM_ARCHS) + message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " + "if you intend on running FP8 quantized MoE models on Hopper.") + else() + message(STATUS "Not building grouped_mm_c3x as no compatible archs found " + "in CUDA target architectures.") + endif() + endif() + + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0f;11.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(SCALED_MM_ARCHS "10.0a;10.1a;10.3a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/grouped_mm_c3x_sm100.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${SCALED_MM_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + list(APPEND VLLM_GPU_FLAGS "-DENABLE_CUTLASS_MOE_SM100=1") + message(STATUS "Building grouped_mm_c3x for archs: ${SCALED_MM_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.8 AND SCALED_MM_ARCHS) + message(STATUS "Not building grouped_mm_c3x kernels as CUDA Compiler version is " + "not >= 12.8, we recommend upgrading to CUDA 12.8 or later " + "if you intend on running FP8 quantized MoE models on Blackwell.") + else() + message(STATUS "Not building grouped_mm_c3x as no compatible archs found " + "in CUDA target architectures.") + endif() + endif() + + # moe_data.cu is used by all CUTLASS MoE kernels. + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 13.0) + cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0f;11.0f;12.0f" "${CUDA_ARCHS}") + else() + cuda_archs_loose_intersection(CUTLASS_MOE_DATA_ARCHS "9.0a;10.0a;10.1a;10.3a;12.0a;12.1a" "${CUDA_ARCHS}") + endif() + if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) + set(SRCS "csrc/libtorch_stable/quantization/w8a8/cutlass/moe/moe_data.cu") + set_gencode_flags_for_srcs( + SRCS "${SRCS}" + CUDA_ARCHS "${CUTLASS_MOE_DATA_ARCHS}") + list(APPEND VLLM_STABLE_EXT_SRC "${SRCS}") + message(STATUS "Building moe_data for archs: ${CUTLASS_MOE_DATA_ARCHS}") + else() + if (NOT ${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER_EQUAL 12.3 AND CUTLASS_MOE_DATA_ARCHS) + message(STATUS "Not building moe_data as CUDA Compiler version is " + "not >= 12.3, we recommend upgrading to CUDA 12.3 or later " + "if you intend on running FP8 quantized MoE models on Hopper or Blackwell.") + else() + message(STATUS "Not building moe_data as no compatible archs found " + "in CUDA target architectures.") + endif() + endif() + message(STATUS "Enabling C_stable extension.") define_extension_target( _C_stable_libtorch @@ -987,6 +986,7 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") SOURCES ${VLLM_STABLE_EXT_SRC} COMPILE_FLAGS ${VLLM_GPU_FLAGS} ARCHITECTURES ${VLLM_GPU_ARCHES} + INCLUDE_DIRECTORIES ${CUTLASS_INCLUDE_DIR} ${CUTLASS_TOOLS_UTIL_INCLUDE_DIR} USE_SABI 3 WITH_SOABI) @@ -1000,6 +1000,10 @@ if(VLLM_GPU_LANG STREQUAL "CUDA") # Needed to use cuda APIs from C-shim target_compile_definitions(_C_stable_libtorch PRIVATE USE_CUDA) + + # Needed by CUTLASS kernels + target_compile_definitions(_C_stable_libtorch PRIVATE + CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL=1) endif() # diff --git a/csrc/cutlass_extensions/common.hpp b/csrc/cutlass_extensions/common.hpp index 91c215071f6e..c32d3a0606af 100644 --- a/csrc/cutlass_extensions/common.hpp +++ b/csrc/cutlass_extensions/common.hpp @@ -6,14 +6,16 @@ #include #include +#include + /** * Helper function for checking CUTLASS errors */ -#define CUTLASS_CHECK(status) \ - { \ - cutlass::Status error = status; \ - TORCH_CHECK(error == cutlass::Status::kSuccess, \ - cutlassGetStatusString(error)); \ +#define CUTLASS_CHECK(status) \ + { \ + cutlass::Status error = status; \ + STD_TORCH_CHECK(error == cutlass::Status::kSuccess, \ + cutlassGetStatusString(error)); \ } inline int get_cuda_max_shared_memory_per_block_opt_in(int const device) { diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp index cf79507e1997..c2ddcea6d027 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp +++ b/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp @@ -3,6 +3,14 @@ #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c3x.hpp" #include "cutlass_extensions/epilogue/broadcast_load_epilogue_array_c3x.hpp" +// This header is shared by both _C (unstable ABI) and _C_stable_libtorch +// (stable ABI) targets. When compiled under the stable ABI target, +// TORCH_TARGET_VERSION is defined and Tensor is unavailable, so we +// use torch::stable::Tensor instead. +#ifdef TORCH_TARGET_VERSION + #include +#endif + /* This file defines custom epilogues for fusing channel scales, token scales, bias, and activation zero-points onto a GEMM operation using the @@ -15,6 +23,12 @@ namespace vllm::c3x { +#ifdef TORCH_TARGET_VERSION +using TensorType = torch::stable::Tensor; +#else +using TensorType = torch::Tensor; +#endif + using namespace cute; template @@ -84,7 +98,7 @@ struct ScaledEpilogueBase { // from a tensor. It can handle both row and column, as well as row/column or // scalar cases. template - static auto args_from_tensor(torch::Tensor const& tensor) { + static auto args_from_tensor(TensorType const& tensor) { using Arguments = typename Descriptor::Arguments; auto* data_ptr = static_cast(tensor.data_ptr()); if constexpr (std::is_same_v> || @@ -100,7 +114,7 @@ struct ScaledEpilogueBase { // This overload handles the case where there might not be a tensor, in which // case a nullptr is passed and a constant (0) is used. template - static auto args_from_tensor(std::optional const& tensor) { + static auto args_from_tensor(std::optional const& tensor) { using Arguments = typename Descriptor::Arguments; auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; static_assert(std::is_same_v> || @@ -158,8 +172,8 @@ struct ScaledEpilogue cutlass::epilogue::fusion::Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { + static ArgumentType prepare_args(TensorType const& a_scales, + TensorType const& b_scales) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); @@ -203,9 +217,9 @@ struct ScaledEpilogueBias cutlass::epilogue::fusion::Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { + static ArgumentType prepare_args(TensorType const& a_scales, + TensorType const& b_scales, + TensorType const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -246,9 +260,9 @@ struct ScaledEpilogueColumnBias cutlass::epilogue::fusion::Sm90EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { + static ArgumentType prepare_args(TensorType const& a_scales, + TensorType const& b_scales, + TensorType const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -304,10 +318,10 @@ struct ScaledEpilogueBiasAzp EVTComputeScaleB, Bias>; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - std::optional const& bias) { + static ArgumentType prepare_args(TensorType const& a_scales, + TensorType const& b_scales, + TensorType const& azp_adj, + std::optional const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -380,11 +394,11 @@ struct ScaledEpilogueBiasAzpToken EVTComputeScaleB, Bias>; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - torch::Tensor const& azp, - std::optional const& bias) { + static ArgumentType prepare_args(TensorType const& a_scales, + TensorType const& b_scales, + TensorType const& azp_adj, + TensorType const& azp, + std::optional const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); diff --git a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp b/csrc/libtorch_stable/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp similarity index 91% rename from csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp rename to csrc/libtorch_stable/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp index ad8c0067d4a9..f6737a73d48a 100644 --- a/csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp +++ b/csrc/libtorch_stable/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp @@ -1,5 +1,7 @@ #pragma once +#include + #include "cutlass_extensions/epilogue/broadcast_load_epilogue_c2x.hpp" /* @@ -52,7 +54,7 @@ struct ScaledEpilogueBase { // from a tensor. It can handle both row and column, as well as row/column or // scalar cases. template - static auto args_from_tensor(torch::Tensor const& tensor) { + static auto args_from_tensor(torch::stable::Tensor const& tensor) { using Arguments = typename Descriptor::Arguments; auto* data_ptr = static_cast(tensor.data_ptr()); if constexpr (std::is_same_v> || @@ -68,7 +70,8 @@ struct ScaledEpilogueBase { // This overload handles the case where there might not be a tensor, in which // case a nullptr is passed and a constant (0) is used. template - static auto args_from_tensor(std::optional const& tensor) { + static auto args_from_tensor( + std::optional const& tensor) { static_assert(std::is_same_v>); using Arguments = typename Descriptor::Arguments; auto* data_ptr = tensor ? static_cast(tensor->data_ptr()) : nullptr; @@ -117,8 +120,8 @@ struct ScaledEpilogue cutlass::epilogue::threadblock::Sm80EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { + static ArgumentType prepare_args(torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); @@ -160,9 +163,9 @@ struct ScaledEpilogueBias using EVTCompute = cutlass::epilogue::threadblock::Sm80EVT; using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& bias) { + static ArgumentType prepare_args(torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + torch::stable::Tensor const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -220,10 +223,11 @@ struct ScaledEpilogueBiasAzp using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - std::optional const& bias) { + static ArgumentType prepare_args( + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + torch::stable::Tensor const& azp_adj, + std::optional const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); @@ -298,11 +302,11 @@ struct ScaledEpilogueBiasAzpToken using ArgumentType = typename EVTCompute::Arguments; - static ArgumentType prepare_args(torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - torch::Tensor const& azp, - std::optional const& bias) { + static ArgumentType prepare_args( + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + torch::stable::Tensor const& azp_adj, torch::stable::Tensor const& azp, + std::optional const& bias) { auto a_args = SUPER::template args_from_tensor(a_scales); auto b_args = SUPER::template args_from_tensor(b_scales); auto bias_args = SUPER::template args_from_tensor(bias); diff --git a/csrc/libtorch_stable/ops.h b/csrc/libtorch_stable/ops.h index b74c5c505f87..914f033b8984 100644 --- a/csrc/libtorch_stable/ops.h +++ b/csrc/libtorch_stable/ops.h @@ -27,4 +27,61 @@ void per_token_group_quant_int8(const torch::stable::Tensor& input, torch::stable::Tensor& output_s, int64_t group_size, double eps, double int8_min, double int8_max); + +bool cutlass_scaled_mm_supports_fp8(int64_t cuda_device_capability); +bool cutlass_scaled_mm_supports_block_fp8(int64_t cuda_device_capability); +bool cutlass_group_gemm_supported(int64_t cuda_device_capability); + +void cutlass_scaled_mm(torch::stable::Tensor& out, + torch::stable::Tensor const& a, + torch::stable::Tensor const& b, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + std::optional const& bias); + +void cutlass_moe_mm(torch::stable::Tensor& out_tensors, + torch::stable::Tensor const& a_tensors, + torch::stable::Tensor const& b_tensors, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + torch::stable::Tensor const& expert_offsets, + torch::stable::Tensor const& problem_sizes, + torch::stable::Tensor const& a_strides, + torch::stable::Tensor const& b_strides, + torch::stable::Tensor const& c_strides, bool per_act_token, + bool per_out_ch); + +void cutlass_scaled_mm_azp(torch::stable::Tensor& out, + torch::stable::Tensor const& a, + torch::stable::Tensor const& b, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + torch::stable::Tensor const& azp_adj, + std::optional const& azp, + std::optional const& bias); + +void get_cutlass_moe_mm_data( + const torch::stable::Tensor& topk_ids, + torch::stable::Tensor& expert_offsets, + torch::stable::Tensor& problem_sizes1, + torch::stable::Tensor& problem_sizes2, + torch::stable::Tensor& input_permutation, + torch::stable::Tensor& output_permutation, const int64_t num_experts, + const int64_t n, const int64_t k, + const std::optional& blockscale_offsets, + const bool is_gated); + +void get_cutlass_moe_mm_problem_sizes_from_expert_offsets( + const torch::stable::Tensor& expert_first_token_offset, + torch::stable::Tensor& problem_sizes1, + torch::stable::Tensor& problem_sizes2, const int64_t n, const int64_t k, + const bool swap_ab); + +void get_cutlass_batched_moe_mm_data( + torch::stable::Tensor& expert_offsets, + torch::stable::Tensor& problem_sizes1, + torch::stable::Tensor& problem_sizes2, + const torch::stable::Tensor& expert_num_tokens, + const int64_t num_local_experts, const int64_t padded_m, const int64_t n, + const int64_t k); #endif diff --git a/csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh similarity index 84% rename from csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh index 26de32ce2b16..ae40c0989e03 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/cutlass_gemm_caller.cuh @@ -2,9 +2,10 @@ // clang-format will break include orders // clang-format off -#include +#include +#include -#include +#include "libtorch_stable/torch_utils.h" #include "cutlass/cutlass.h" @@ -25,14 +26,14 @@ namespace vllm::c3x { static inline cute::Shape get_problem_shape( - torch::Tensor const& a, torch::Tensor const& b) { + torch::stable::Tensor const& a, torch::stable::Tensor const& b) { int32_t m = a.size(0), n = b.size(1), k = a.size(1); return {m, n, k, 1}; } template void cutlass_gemm_caller( - torch::Device device, cute::Shape prob_shape, + torch::stable::Device device, cute::Shape prob_shape, typename GemmKernel::MainloopArguments mainloop_args, typename GemmKernel::EpilogueArguments epilogue_args, typename GemmKernel::TileSchedulerArguments scheduler = {}) { @@ -50,19 +51,20 @@ void cutlass_gemm_caller( CUTLASS_CHECK(gemm_op.can_implement(args)); size_t workspace_size = gemm_op.get_workspace_size(args); - auto const workspace_options = - torch::TensorOptions().dtype(torch::kUInt8).device(device); - auto workspace = torch::empty(workspace_size, workspace_options); + auto workspace = + torch::stable::empty(workspace_size, torch::headeronly::ScalarType::Byte, + std::nullopt, device); - auto stream = at::cuda::getCurrentCUDAStream(device.index()); + auto stream = get_current_cuda_stream(device.index()); cutlass::Status status = gemm_op.run(args, workspace.data_ptr(), stream); CUTLASS_CHECK(status); } template -void cutlass_gemm_caller(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, +void cutlass_gemm_caller(torch::stable::Tensor& out, + torch::stable::Tensor const& a, + torch::stable::Tensor const& b, EpilogueArgs&&... epilogue_params) { using ElementAB = typename Gemm::ElementAB; using ElementC = typename Gemm::ElementC; diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh similarity index 100% rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm.cuh rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm.cuh diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu similarity index 51% rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu index 4cd38f4975df..bc088cf633f4 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_azp_sm90_int8.cu @@ -4,13 +4,12 @@ namespace vllm { -void cutlass_scaled_mm_azp_sm90_int8(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - torch::Tensor const& azp_adj, - std::optional const& azp, - std::optional const& bias) { +void cutlass_scaled_mm_azp_sm90_int8( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj, + std::optional const& azp, + std::optional const& bias) { if (azp) { return cutlass_scaled_mm_sm90_int8_epilogue< c3x::ScaledEpilogueBiasAzpToken>(out, a, b, a_scales, b_scales, azp_adj, diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu new file mode 100644 index 000000000000..f3df69850ec6 --- /dev/null +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8.cu @@ -0,0 +1,22 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_blockwise_sm100_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_blockwise_sm100_fp8( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales) { + if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) { + cutlass_gemm_blockwise_sm100_fp8_dispatch( + out, a, b, a_scales, b_scales); + + } else { + STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half); + cutlass_gemm_blockwise_sm100_fp8_dispatch( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh similarity index 93% rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh index db3b26c084ee..cf84fc3a6de1 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm100_fp8_dispatch.cuh @@ -1,5 +1,7 @@ #pragma once +#include + #include "cuda_utils.h" #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" @@ -130,10 +132,10 @@ struct cutlass_3x_gemm_fp8_blockwise { }; template -void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { +void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales) { static constexpr bool swap_ab = Gemm::swap_ab; using GemmKernel = typename Gemm::GemmKernel; using StrideA = typename Gemm::GemmKernel::StrideA; @@ -200,11 +202,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, } template -void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::Tensor& out, - torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { +void cutlass_gemm_blockwise_sm100_fp8_dispatch(torch::stable::Tensor& out, + torch::stable::Tensor const& a, + torch::stable::Tensor const& b, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales) { int32_t m = a.size(0), n = b.size(1), k = a.size(1), sms; cudaDeviceGetAttribute(&sms, cudaDevAttrMultiProcessorCount, a.get_device()); diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu new file mode 100644 index 000000000000..7ceb0697df2a --- /dev/null +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8.cu @@ -0,0 +1,22 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_blockwise_sm120_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_blockwise_sm120_fp8( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales) { + if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) { + cutlass_gemm_blockwise_sm120_fp8_dispatch( + out, a, b, a_scales, b_scales); + + } else { + STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half); + cutlass_gemm_blockwise_sm120_fp8_dispatch( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh similarity index 91% rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh index 0701cb5866cf..14de9b1e4fd9 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm120_fp8_dispatch.cuh @@ -1,5 +1,7 @@ #pragma once +#include + #include "cuda_utils.h" #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" @@ -138,10 +140,10 @@ struct sm120_blockwise_fp8_config_M64 { }; template -void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { +void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales) { using GemmKernel = typename Gemm::GemmKernel; using StrideA = typename Gemm::GemmKernel::StrideA; using StrideB = typename Gemm::GemmKernel::StrideB; @@ -196,11 +198,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, } template -void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::Tensor& out, - torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { +void cutlass_gemm_blockwise_sm120_fp8_dispatch(torch::stable::Tensor& out, + torch::stable::Tensor const& a, + torch::stable::Tensor const& b, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales) { int M = a.size(0); if (M <= 256) { using Gemm = typename sm120_blockwise_fp8_config_M64::Gemm; diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu new file mode 100644 index 000000000000..d3318c487675 --- /dev/null +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8.cu @@ -0,0 +1,23 @@ + +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_blockwise_sm90_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_blockwise_sm90_fp8( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales) { + if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) { + cutlass_gemm_blockwise_sm90_fp8_dispatch( + out, a, b, a_scales, b_scales); + + } else { + STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half); + cutlass_gemm_blockwise_sm90_fp8_dispatch( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh similarity index 87% rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh index c40d49966271..cf62e81fd75b 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_blockwise_sm90_fp8_dispatch.cuh @@ -1,5 +1,7 @@ #pragma once +#include + #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" @@ -101,10 +103,10 @@ struct cutlass_3x_gemm_fp8_blockwise { }; template -void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { +void cutlass_gemm_caller_blockwise(torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales) { using GemmKernel = typename Gemm::GemmKernel; using StrideA = typename Gemm::GemmKernel::StrideA; using StrideB = typename Gemm::GemmKernel::StrideB; @@ -120,7 +122,7 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, int32_t m = a.size(0), n = b.size(1), k = a.size(1); - TORCH_CHECK(m % 4 == 0, "m must be divisible by 4"); + STD_TORCH_CHECK(m % 4 == 0, "m must be divisible by 4"); StrideA a_stride; StrideB b_stride; @@ -161,11 +163,11 @@ void cutlass_gemm_caller_blockwise(torch::Tensor& out, torch::Tensor const& a, } template -void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::Tensor& out, - torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales) { +void cutlass_gemm_blockwise_sm90_fp8_dispatch(torch::stable::Tensor& out, + torch::stable::Tensor const& a, + torch::stable::Tensor const& b, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales) { // TODO: better heuristics cutlass_gemm_caller_blockwise, diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp similarity index 55% rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp index 2204a49257b0..adb3de50fc1b 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_helper.hpp @@ -1,52 +1,57 @@ -#include +#include +#include #include "cuda_utils.h" #include "cutlass_extensions/common.hpp" template -void dispatch_scaled_mm(torch::Tensor& c, torch::Tensor const& a, - torch::Tensor const& b, torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - std::optional const& bias, +void dispatch_scaled_mm(torch::stable::Tensor& c, + torch::stable::Tensor const& a, + torch::stable::Tensor const& b, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + std::optional const& bias, Fp8Func fp8_func, Int8Func int8_func, BlockwiseFunc blockwise_func) { - TORCH_CHECK(a_scales.dtype() == torch::kFloat32); - TORCH_CHECK(b_scales.dtype() == torch::kFloat32); + STD_TORCH_CHECK(a_scales.scalar_type() == + torch::headeronly::ScalarType::Float); + STD_TORCH_CHECK(b_scales.scalar_type() == + torch::headeronly::ScalarType::Float); int M = a.size(0), N = b.size(1), K = a.size(1); if ((a_scales.numel() == 1 || a_scales.numel() == a.size(0)) && (b_scales.numel() == 1 || b_scales.numel() == b.size(1))) { // Standard per-tensor/per-token/per-channel scaling - TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); - if (a.dtype() == torch::kFloat8_e4m3fn) { + STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (a.scalar_type() == torch::headeronly::ScalarType::Float8_e4m3fn) { fp8_func(c, a, b, a_scales, b_scales, bias); } else { - TORCH_CHECK(a.dtype() == torch::kInt8); + STD_TORCH_CHECK(a.scalar_type() == torch::headeronly::ScalarType::Char); if constexpr (!std::is_same_v) { int8_func(c, a, b, a_scales, b_scales, bias); } else { int32_t version_num = get_sm_version_num(); - TORCH_CHECK( + STD_TORCH_CHECK( false, "Int8 not supported on SM", version_num, ". Use FP8 quantization instead, or run on older arch (SM < 100)."); } } } else { - TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor."); - TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor."); + STD_TORCH_CHECK(a_scales.dim() == 2, "a scale must be 2d tensor."); + STD_TORCH_CHECK(b_scales.dim() == 2, "b scale must be 2d tensor."); int32_t version_num = get_sm_version_num(); if (version_num >= 90) { - TORCH_CHECK( + STD_TORCH_CHECK( a.size(0) == a_scales.size(0) && cuda_utils::ceil_div(a.size(1), int64_t(128)) == a_scales.size(1), "a_scale_group_shape must be [1, 128]."); - TORCH_CHECK( + STD_TORCH_CHECK( cuda_utils::ceil_div(b.size(0), int64_t(128)) == b_scales.size(0) && cuda_utils::ceil_div(b.size(1), int64_t(128)) == b_scales.size(1), "b_scale_group_shape must be [128, 128]."); } - TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); + STD_TORCH_CHECK(!bias, "Bias not yet supported blockwise scaled_mm"); blockwise_func(c, a, b, a_scales, b_scales); } } diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp new file mode 100644 index 000000000000..a3a3eb3e2875 --- /dev/null +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_kernels.hpp @@ -0,0 +1,52 @@ +#pragma once + +#include + +namespace vllm { + +void cutlass_scaled_mm_sm90_fp8( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + std::optional const& bias); + +void cutlass_scaled_mm_sm90_int8( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + std::optional const& bias); + +void cutlass_scaled_mm_azp_sm90_int8( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, torch::stable::Tensor const& azp_adj, + std::optional const& azp, + std::optional const& bias); + +void cutlass_scaled_mm_blockwise_sm90_fp8( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales); + +void cutlass_scaled_mm_sm100_fp8( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + std::optional const& bias); + +void cutlass_scaled_mm_sm120_fp8( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + std::optional const& bias); + +void cutlass_scaled_mm_blockwise_sm100_fp8( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales); + +void cutlass_scaled_mm_blockwise_sm120_fp8( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales); +} // namespace vllm diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu new file mode 100644 index 000000000000..e910103c4eae --- /dev/null +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8.cu @@ -0,0 +1,24 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_sm100_fp8_dispatch.cuh" + +namespace vllm { + +void cutlass_scaled_mm_sm100_fp8( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + std::optional const& bias) { + STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (bias) { + STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(), + "currently bias dtype must match output dtype ", + out.scalar_type()); + return cutlass_scaled_mm_sm100_fp8_epilogue(out, a, b, a_scales, + b_scales, *bias); + } else { + return cutlass_scaled_mm_sm100_fp8_epilogue(out, a, b, a_scales, + b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh similarity index 90% rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh index 311cd4bd41c5..46c36d13ece4 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm100_fp8_dispatch.cuh @@ -1,5 +1,7 @@ #pragma once +#include + #include "scaled_mm.cuh" #include "cutlass_gemm_caller.cuh" #include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" @@ -192,8 +194,9 @@ struct sm100_fp8_config_M16_swap_ab { }; template -void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, - torch::Tensor const& b, +void cutlass_gemm_caller_sm100_fp8(torch::stable::Tensor& out, + torch::stable::Tensor const& a, + torch::stable::Tensor const& b, EpilogueArgs&&... epilogue_params) { static constexpr bool swap_ab = Gemm::swap_ab; using ElementAB = typename Gemm::ElementAB; @@ -237,15 +240,15 @@ void cutlass_gemm_caller_sm100_fp8(torch::Tensor& out, torch::Tensor const& a, template -inline void cutlass_gemm_sm100_fp8_dispatch(torch::Tensor& out, - torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales, - EpilogueArgs&&... args) { +inline void cutlass_gemm_sm100_fp8_dispatch( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, EpilogueArgs&&... args) { static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + STD_TORCH_CHECK(a.scalar_type() == + torch::headeronly::ScalarType::Float8_e4m3fn); + STD_TORCH_CHECK(b.scalar_type() == + torch::headeronly::ScalarType::Float8_e4m3fn); using Cutlass3xGemmDefault = typename sm100_fp8_config_default -void cutlass_scaled_mm_sm100_fp8_epilogue(torch::Tensor& out, - torch::Tensor const& a, - torch::Tensor const& b, - torch::Tensor const& a_scales, - torch::Tensor const& b_scales, +void cutlass_scaled_mm_sm100_fp8_epilogue(torch::stable::Tensor& out, + torch::stable::Tensor const& a, + torch::stable::Tensor const& b, + torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, EpilogueArgs&&... epilogue_args) { - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + STD_TORCH_CHECK(a.scalar_type() == + torch::headeronly::ScalarType::Float8_e4m3fn); + STD_TORCH_CHECK(b.scalar_type() == + torch::headeronly::ScalarType::Float8_e4m3fn); - if (out.dtype() == torch::kBFloat16) { + if (out.scalar_type() == torch::headeronly::ScalarType::BFloat16) { return cutlass_gemm_sm100_fp8_dispatch( out, a, b, a_scales, b_scales, std::forward(epilogue_args)...); } else { - TORCH_CHECK(out.dtype() == torch::kFloat16); + STD_TORCH_CHECK(out.scalar_type() == torch::headeronly::ScalarType::Half); return cutlass_gemm_sm100_fp8_dispatch( out, a, b, a_scales, b_scales, diff --git a/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu new file mode 100644 index 000000000000..fb84faa2a41a --- /dev/null +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8.cu @@ -0,0 +1,25 @@ +#include "scaled_mm_kernels.hpp" +#include "scaled_mm_sm120_fp8_dispatch.cuh" +#include "cutlass_extensions/epilogue/scaled_mm_epilogues_c3x.hpp" + +namespace vllm { + +void cutlass_scaled_mm_sm120_fp8( + torch::stable::Tensor& out, torch::stable::Tensor const& a, + torch::stable::Tensor const& b, torch::stable::Tensor const& a_scales, + torch::stable::Tensor const& b_scales, + std::optional const& bias) { + STD_TORCH_CHECK(a_scales.is_contiguous() && b_scales.is_contiguous()); + if (bias) { + STD_TORCH_CHECK(bias->scalar_type() == out.scalar_type(), + "currently bias dtype must match output dtype ", + out.scalar_type()); + return cutlass_scaled_mm_sm120_fp8_epilogue( + out, a, b, a_scales, b_scales, *bias); + } else { + return cutlass_scaled_mm_sm120_fp8_epilogue( + out, a, b, a_scales, b_scales); + } +} + +} // namespace vllm diff --git a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh similarity index 88% rename from csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh rename to csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh index 37846a87bbfb..245f5c10fcad 100644 --- a/csrc/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh +++ b/csrc/libtorch_stable/quantization/w8a8/cutlass/c3x/scaled_mm_sm120_fp8_dispatch.cuh @@ -1,5 +1,7 @@ #pragma once +#include + #include "scaled_mm.cuh" #include "cutlass_gemm_caller.cuh" @@ -138,13 +140,15 @@ struct sm120_fp8_config_M16 { template typename Epilogue, typename... EpilogueArgs> -inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out, - torch::Tensor const& a, - torch::Tensor const& b, +inline void cutlass_gemm_sm120_fp8_dispatch(torch::stable::Tensor& out, + torch::stable::Tensor const& a, + torch::stable::Tensor const& b, EpilogueArgs&&... args) { static_assert(std::is_same()); - TORCH_CHECK(a.dtype() == torch::kFloat8_e4m3fn); - TORCH_CHECK(b.dtype() == torch::kFloat8_e4m3fn); + STD_TORCH_CHECK(a.scalar_type() == + torch::headeronly::ScalarType::Float8_e4m3fn); + STD_TORCH_CHECK(b.scalar_type() == + torch::headeronly::ScalarType::Float8_e4m3fn); int M = a.size(0); @@ -177,19 +181,21 @@ inline void cutlass_gemm_sm120_fp8_dispatch(torch::Tensor& out, template