From 9bbcb88bc104f471775bc39545f3915d5fa2af32 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Sat, 2 May 2026 09:08:48 +0430 Subject: [PATCH] chore: sync to upstream 985961345a13f3e3bb15d29c94b011ba9a6b858b --- CMakeLists.txt | 1 - aphrodite/_aiter_ops.py | 124 +- aphrodite/_custom_ops.py | 27 +- aphrodite/compilation/backends.py | 3 +- aphrodite/compilation/cuda_graph.py | 9 +- aphrodite/compilation/decorators.py | 89 +- .../passes/fusion/act_quant_fusion.py | 1 + .../passes/fusion/allreduce_rms_fusion.py | 198 ++- .../passes/fusion/collective_fusion.py | 319 +++- .../passes/fusion/sequence_parallelism.py | 10 +- aphrodite/compilation/passes/pass_manager.py | 8 +- aphrodite/compilation/wrapper.py | 21 +- aphrodite/config/aphrodite.py | 37 +- aphrodite/config/attention.py | 62 +- aphrodite/config/model.py | 2 +- aphrodite/config/parallel.py | 28 + aphrodite/config/speculative.py | 59 +- .../device_communicators/all2all.py | 25 +- .../distributed/eplb/eplb_communicator.py | 386 ++--- .../distributed/eplb/rebalance_execute.py | 13 +- .../kv_connector/v1/multi_connector.py | 44 +- .../kv_connector/v1/nixl/scheduler.py | 87 +- .../kv_connector/v1/nixl/worker.py | 2 +- .../kv_connector/v1/offloading/common.py | 51 +- .../kv_connector/v1/offloading/scheduler.py | 508 ++++-- .../kv_connector/v1/offloading/worker.py | 108 +- .../kv_connector/v1/offloading_connector.py | 19 +- aphrodite/distributed/parallel_state.py | 10 +- aphrodite/engine/protocol.py | 1 + aphrodite/entrypoints/anthropic/protocol.py | 4 + aphrodite/entrypoints/anthropic/serving.py | 19 + aphrodite/entrypoints/chat_utils.py | 256 ++- aphrodite/entrypoints/llm.py | 2 +- .../openai/chat_completion/batch_serving.py | 1 + .../openai/chat_completion/protocol.py | 31 +- .../openai/chat_completion/serving.py | 204 +-- aphrodite/entrypoints/openai/cli_args.py | 13 +- .../entrypoints/openai/completion/protocol.py | 3 + .../entrypoints/openai/completion/serving.py | 9 +- .../entrypoints/openai/engine/protocol.py | 9 + .../entrypoints/openai/engine/serving.py | 13 + aphrodite/entrypoints/openai/fingerprint.py | 81 + .../entrypoints/openai/generate/api_router.py | 8 + .../openai/parser/harmony_utils.py | 9 +- .../entrypoints/openai/responses/serving.py | 612 +------ .../openai/responses/streaming_events.py | 450 +++++- aphrodite/entrypoints/serve/render/serving.py | 2 +- aphrodite/env_override.py | 125 +- aphrodite/envs.py | 47 +- aphrodite/inputs/engine.py | 19 + aphrodite/inputs/llm.py | 11 + aphrodite/lora/worker_manager.py | 34 +- .../kernels/linear/scaled_mm/pytorch.py | 10 +- .../layers/attention/mla_attention.py | 594 +------ .../model_executor/layers/batch_invariant.py | 20 +- .../layers/deepseek_compressor.py | 9 +- .../layers/deepseek_v4_attention.py | 152 +- .../model_executor/layers/fla/ops/kda.py | 12 +- .../layers/fused_moe/__init__.py | 18 +- .../layers/fused_moe/all2all_utils.py | 30 +- .../model_executor/layers/fused_moe/config.py | 40 +- ...880,device_name=NVIDIA_H100_80GB_HBM3.json | 147 ++ .../experts/gpt_oss_triton_kernels_moe.py | 5 +- .../fused_moe/experts/trtllm_fp8_moe.py | 52 +- .../fused_moe/experts/trtllm_mxfp4_moe.py | 47 +- .../fused_moe/experts/trtllm_nvfp4_moe.py | 49 +- .../fused_moe/flashinfer_cutlass_moe.py | 2 +- .../layers/fused_moe/fused_batched_moe.py | 4 +- .../layers/fused_moe/fused_humming_moe.py | 4 +- .../layers/fused_moe/fused_marlin_moe.py | 4 +- .../layers/fused_moe/fused_moe.py | 4 +- .../layers/fused_moe/fused_moe_method_base.py | 4 - .../model_executor/layers/fused_moe/layer.py | 24 +- .../layers/fused_moe/oracle/int_wna16.py | 4 +- .../layers/fused_moe/oracle/mxfp4.py | 25 +- .../layers/fused_moe/oracle/nvfp4.py | 8 +- .../fused_moe/prepare_finalize/deepep_ll.py | 64 +- .../flashinfer_nvlink_one_sided.py | 31 +- .../flashinfer_nvlink_two_sided.py | 1 + .../fused_moe/prepare_finalize/naive_dp_ep.py | 1 + .../fused_moe/prepare_finalize/no_dp_ep.py | 1 + .../layers/fused_moe/rocm_aiter_fused_moe.py | 3 +- .../fused_moe/routed_experts_capturer.py | 21 +- .../fused_moe/router/custom_routing_router.py | 4 + .../router/fused_topk_bias_router.py | 2 - .../layers/fused_moe/router/gate_linear.py | 2 +- .../layers/fused_moe/runner/moe_runner.py | 16 +- .../fused_moe/runner/moe_runner_interface.py | 3 +- .../layers/fused_moe/runner/shared_experts.py | 16 +- .../layers/fused_moe/triton_cutlass_moe.py | 4 +- .../fused_moe/unquantized_fused_moe_method.py | 8 +- .../model_executor/layers/fused_moe/utils.py | 10 +- aphrodite/model_executor/layers/linear.py | 5 +- .../layers/mamba/linear_attn.py | 37 +- .../layers/mamba/mamba_utils.py | 3 - aphrodite/model_executor/layers/mhc.py | 130 ++ aphrodite/model_executor/layers/mla.py | 11 +- .../layers/pooler/tokwise/methods.py | 20 +- .../layers/quantization/modelopt.py | 14 +- .../layers/quantization/quark/quark.py | 26 +- .../layers/quantization/quark/quark_moe.py | 18 +- .../quark/schemes/quark_ocp_mx.py | 33 +- .../layers/quantization/utils/mxfp8_utils.py | 24 +- .../utils/nvfp4_emulation_utils.py | 294 +++- .../rotary_embedding/deepseek_scaling_rope.py | 15 +- aphrodite/model_executor/layers/utils.py | 32 +- .../model_loader/base_loader.py | 2 +- .../model_loader/default_loader.py | 33 +- .../model_loader/reload/layerwise.py | 39 +- .../model_loader/reload/utils.py | 31 +- .../models/bailing_moe_linear.py | 30 +- .../model_executor/models/cohere2_vision.py | 15 +- aphrodite/model_executor/models/cohere_asr.py | 77 +- aphrodite/model_executor/models/cohere_moe.py | 485 ++++++ .../model_executor/models/deepseek_v4.py | 187 ++- .../model_executor/models/deepseek_v4_mtp.py | 32 +- aphrodite/model_executor/models/gemma4.py | 31 +- .../model_executor/models/granite4_vision.py | 3 +- aphrodite/model_executor/models/laguna.py | 827 ++++++++++ aphrodite/model_executor/models/llama.py | 15 +- .../model_executor/models/longcat_flash.py | 139 +- aphrodite/model_executor/models/mimo_audio.py | 1269 +++++++++++++++ .../models/{mimo_v2_flash.py => mimo_v2.py} | 24 +- .../model_executor/models/mimo_v2_mtp.py | 346 ++++ .../model_executor/models/mimo_v2_omni.py | 1417 +++++++++++++++++ aphrodite/model_executor/models/minimax_m2.py | 14 +- .../model_executor/models/mistral_eagle.py | 162 ++ aphrodite/model_executor/models/moondream3.py | 1370 ++++++++++++++++ aphrodite/model_executor/models/qwen2.py | 39 +- aphrodite/model_executor/models/registry.py | 11 +- aphrodite/multimodal/cache.py | 24 +- aphrodite/multimodal/registry.py | 3 +- aphrodite/parser/abstract_parser.py | 86 +- aphrodite/platforms/cpu.py | 13 +- aphrodite/platforms/interface.py | 3 + aphrodite/platforms/rocm.py | 13 +- aphrodite/reasoning/__init__.py | 12 + .../cohere_command_reasoning_parser.py | 519 ++++++ aphrodite/reasoning/olmo3_reasoning_parser.py | 38 +- .../reasoning/poolside_v1_reasoning_parser.py | 68 + aphrodite/renderers/base.py | 4 +- aphrodite/renderers/embed_utils.py | 50 +- aphrodite/renderers/hf.py | 432 ++++- aphrodite/sampling_params.py | 25 + aphrodite/tokenizers/deepseek_v4.py | 10 +- aphrodite/tool_parsers/__init__.py | 12 + .../cohere_command_tool_parser.py | 125 ++ .../tool_parsers/deepseekv32_tool_parser.py | 96 +- .../tool_parsers/poolside_v1_tool_parser.py | 554 +++++++ aphrodite/tool_parsers/streaming.py | 189 +++ aphrodite/transformers_utils/config.py | 49 +- .../transformers_utils/configs/__init__.py | 8 + .../transformers_utils/configs/laguna.py | 120 ++ .../configs/mimo_v2_omni.py | 61 + .../transformers_utils/configs/moondream3.py | 152 ++ .../model_arch_config_convertor.py | 39 + .../transformers_utils/processors/__init__.py | 4 + .../processors/mimo_v2_omni.py | 1181 ++++++++++++++ .../processors/moondream3.py | 522 ++++++ aphrodite/utils/flashinfer.py | 33 + aphrodite/utils/multi_stream_utils.py | 66 + aphrodite/v1/attention/backends/cpu_attn.py | 41 +- aphrodite/v1/attention/backends/flashinfer.py | 115 +- .../v1/attention/backends/flex_attention.py | 20 +- .../v1/attention/backends/mla/indexer.py | 7 +- .../backends/mla/prefill/__init__.py | 11 + .../v1/attention/backends/mla/prefill/base.py | 124 ++ .../backends/mla/prefill/flash_attn.py | 174 ++ .../backends/mla/prefill/flashinfer.py | 204 +++ .../backends/mla/prefill/registry.py | 43 + .../backends/mla/prefill/selector.py | 170 ++ .../backends/mla/prefill/trtllm_ragged.py | 172 ++ .../attention/backends/mla/rocm_aiter_mla.py | 2 + .../v1/attention/backends/mla/triton_mla.py | 10 - aphrodite/v1/attention/ops/dcp_alltoall.py | 374 +++-- .../fused_compress_quant_cache.py | 10 +- .../ops/deepseek_v4_ops/fused_indexer_q.py | 55 +- .../fused_inv_rope_fp8_quant.py | 144 +- .../v1/attention/ops/rocm_aiter_mla_sparse.py | 39 +- aphrodite/v1/attention/selector.py | 5 +- aphrodite/v1/core/kv_cache_coordinator.py | 13 +- aphrodite/v1/core/kv_cache_manager.py | 64 +- aphrodite/v1/core/kv_cache_utils.py | 105 +- aphrodite/v1/core/sched/output.py | 2 + aphrodite/v1/core/sched/scheduler.py | 14 +- .../v1/core/single_type_kv_cache_manager.py | 8 +- aphrodite/v1/engine/__init__.py | 7 + aphrodite/v1/engine/async_llm.py | 7 +- aphrodite/v1/engine/core.py | 88 +- aphrodite/v1/engine/input_processor.py | 5 +- aphrodite/v1/engine/logprobs.py | 2 +- aphrodite/v1/kv_cache_interface.py | 4 + aphrodite/v1/kv_offload/abstract.py | 197 --- aphrodite/v1/kv_offload/base.py | 371 +++++ aphrodite/v1/kv_offload/cpu/common.py | 13 + .../{worker/cpu_gpu.py => cpu/gpu_worker.py} | 8 +- aphrodite/v1/kv_offload/cpu/manager.py | 24 +- aphrodite/v1/kv_offload/cpu/policies/arc.py | 4 +- .../cpu/policies/{abstract.py => base.py} | 2 +- aphrodite/v1/kv_offload/cpu/policies/lru.py | 4 +- aphrodite/v1/kv_offload/cpu/spec.py | 13 +- aphrodite/v1/kv_offload/factory.py | 2 +- aphrodite/v1/kv_offload/mediums.py | 68 - aphrodite/v1/kv_offload/reuse_manager.py | 15 +- aphrodite/v1/kv_offload/spec.py | 141 -- aphrodite/v1/kv_offload/worker/worker.py | 2 +- aphrodite/v1/metrics/ray_wrappers.py | 31 +- aphrodite/v1/request.py | 9 + .../v1/sample/logits_processor/__init__.py | 3 - .../v1/sample/logits_processor/builtin.py | 232 +-- aphrodite/v1/sample/metadata.py | 6 + aphrodite/v1/sample/ops/topk_topp_sampler.py | 44 +- aphrodite/v1/sample/rejection_sampler.py | 15 +- aphrodite/v1/sample/sampler.py | 58 + aphrodite/v1/sample/thinking_budget_state.py | 477 ++++++ aphrodite/v1/spec_decode/dflash.py | 13 +- aphrodite/v1/spec_decode/llm_base_proposer.py | 34 +- aphrodite/v1/structured_output/__init__.py | 33 +- aphrodite/v1/structured_output/request.py | 9 +- aphrodite/v1/worker/cpu_model_runner.py | 11 + aphrodite/v1/worker/gpu/block_table.py | 25 +- aphrodite/v1/worker/gpu/cudagraph_utils.py | 38 +- aphrodite/v1/worker/gpu/kv_connector.py | 2 +- aphrodite/v1/worker/gpu/mm/rope.py | 2 +- aphrodite/v1/worker/gpu/model_runner.py | 20 +- .../v1/worker/gpu/model_states/default.py | 4 +- aphrodite/v1/worker/gpu/sample/logprob.py | 140 +- aphrodite/v1/worker/gpu/sample/sampler.py | 22 +- .../worker/gpu/spec_decode/eagle/cudagraph.py | 53 +- .../gpu/spec_decode/eagle/speculator.py | 33 +- aphrodite/v1/worker/gpu_input_batch.py | 36 +- aphrodite/v1/worker/gpu_model_runner.py | 56 +- aphrodite/v1/worker/gpu_ubatch_wrapper.py | 4 +- aphrodite/v1/worker/gpu_worker.py | 40 +- cmake/external_projects/deepgemm.cmake | 27 +- csrc/cpu/cpu_attn.cpp | 103 +- csrc/cpu/cpu_attn_amx.hpp | 217 ++- csrc/cpu/cpu_attn_fp8.hpp | 214 +++ csrc/cpu/cpu_attn_impl.hpp | 38 +- csrc/cpu/cpu_attn_neon.hpp | 9 +- csrc/cpu/cpu_attn_neon_bfmmla.hpp | 3 +- csrc/cpu/cpu_attn_vec.hpp | 133 +- csrc/cpu/cpu_attn_vec16.hpp | 6 +- csrc/cpu/cpu_attn_vxe.hpp | 7 +- csrc/cpu/cpu_types_arm.hpp | 6 + csrc/cpu/cpu_types_vxe.hpp | 6 + csrc/cpu/cpu_types_x86.hpp | 139 ++ csrc/cpu/generate_cpu_attn_dispatch.py | 262 +-- csrc/cpu/torch_bindings.cpp | 16 +- csrc/cutlass_extensions/common.hpp | 45 +- .../w8a8/cutlass/c3x/scaled_mm.cuh | 2 +- ...scaled_mm_blockwise_sm100_fp8_dispatch.cuh | 2 +- .../c3x/scaled_mm_sm100_fp8_dispatch.cuh | 2 +- .../w8a8/fp8/per_token_group_quant.cu | 258 ++- .../w8a8/per_token_group_quant_8bit.h | 10 + csrc/libtorch_stable/torch_bindings.cpp | 8 + csrc/moe/moe_ops.h | 3 - csrc/moe/router_gemm.cu | 52 - csrc/moe/torch_bindings.cpp | 4 - csrc/persistent_topk.cuh | 25 +- csrc/pos_encoding_kernels.cu | 76 +- csrc/topk.cu | 86 +- tools/report_build_time_ninja.py | 4 +- 263 files changed, 19152 insertions(+), 4404 deletions(-) create mode 100644 aphrodite/entrypoints/openai/fingerprint.py create mode 100644 aphrodite/model_executor/layers/fused_moe/configs/E=128,N=2880,device_name=NVIDIA_H100_80GB_HBM3.json create mode 100644 aphrodite/model_executor/models/cohere_moe.py create mode 100644 aphrodite/model_executor/models/laguna.py create mode 100644 aphrodite/model_executor/models/mimo_audio.py rename aphrodite/model_executor/models/{mimo_v2_flash.py => mimo_v2.py} (96%) create mode 100644 aphrodite/model_executor/models/mimo_v2_mtp.py create mode 100644 aphrodite/model_executor/models/mimo_v2_omni.py create mode 100644 aphrodite/model_executor/models/mistral_eagle.py create mode 100644 aphrodite/model_executor/models/moondream3.py create mode 100644 aphrodite/reasoning/cohere_command_reasoning_parser.py create mode 100644 aphrodite/reasoning/poolside_v1_reasoning_parser.py create mode 100644 aphrodite/tool_parsers/cohere_command_tool_parser.py create mode 100644 aphrodite/tool_parsers/poolside_v1_tool_parser.py create mode 100644 aphrodite/tool_parsers/streaming.py create mode 100644 aphrodite/transformers_utils/configs/laguna.py create mode 100644 aphrodite/transformers_utils/configs/mimo_v2_omni.py create mode 100644 aphrodite/transformers_utils/configs/moondream3.py create mode 100644 aphrodite/transformers_utils/processors/mimo_v2_omni.py create mode 100644 aphrodite/transformers_utils/processors/moondream3.py create mode 100644 aphrodite/v1/attention/backends/mla/prefill/__init__.py create mode 100644 aphrodite/v1/attention/backends/mla/prefill/base.py create mode 100644 aphrodite/v1/attention/backends/mla/prefill/flash_attn.py create mode 100644 aphrodite/v1/attention/backends/mla/prefill/flashinfer.py create mode 100644 aphrodite/v1/attention/backends/mla/prefill/registry.py create mode 100644 aphrodite/v1/attention/backends/mla/prefill/selector.py create mode 100644 aphrodite/v1/attention/backends/mla/prefill/trtllm_ragged.py delete mode 100644 aphrodite/v1/kv_offload/abstract.py create mode 100644 aphrodite/v1/kv_offload/base.py create mode 100644 aphrodite/v1/kv_offload/cpu/common.py rename aphrodite/v1/kv_offload/{worker/cpu_gpu.py => cpu/gpu_worker.py} (98%) rename aphrodite/v1/kv_offload/cpu/policies/{abstract.py => base.py} (97%) delete mode 100644 aphrodite/v1/kv_offload/mediums.py delete mode 100644 aphrodite/v1/kv_offload/spec.py create mode 100644 aphrodite/v1/sample/thinking_budget_state.py create mode 100644 csrc/cpu/cpu_attn_fp8.hpp delete mode 100644 csrc/moe/router_gemm.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 8922863cbf..9b41fbb183 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -948,7 +948,6 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA") list(APPEND APHRODITE_MOE_EXT_SRC "csrc/moe/moe_wna16.cu" "csrc/moe/grouped_topk_kernels.cu" - "csrc/moe/router_gemm.cu" "csrc/moe/topk_softplus_sqrt_kernels.cu") endif() diff --git a/aphrodite/_aiter_ops.py b/aphrodite/_aiter_ops.py index 5135800f5c..1b56b8fb57 100644 --- a/aphrodite/_aiter_ops.py +++ b/aphrodite/_aiter_ops.py @@ -2,9 +2,12 @@ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import functools from collections.abc import Callable +from contextlib import contextmanager +from typing import Protocol import torch from torch._ops import OpOverload +from torch.distributed import ProcessGroup import aphrodite.envs as envs from aphrodite.platforms import current_platform @@ -39,6 +42,27 @@ def is_aiter_found() -> bool: IS_AITER_FOUND = is_aiter_found() +class AiterCustomAllreduceProto(Protocol): + max_size: int + world_size: int + fully_connected: bool + + @contextmanager + def capture(self): ... + def close(self) -> None: ... + def fused_ar_rms( + self, + inp: torch.Tensor, + res_inp: torch.Tensor, + *, + w: torch.Tensor, + eps: float, + registered: bool = False, + use_1stage: bool = False, + ) -> tuple[torch.Tensor, torch.Tensor]: ... + def should_custom_ar(self, inp: torch.Tensor) -> bool: ... + + def is_aiter_found_and_supported() -> bool: """Check if AITER library is available and platform supports it. @@ -731,6 +755,55 @@ def _rocm_aiter_per_tensor_quant_impl( return per_tensor_quant_hip(x, scale, quant_dtype) +def _rocm_aiter_fused_allreduce_rmsnorm_impl( + input_: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + aiter_ar = rocm_aiter_ops.get_aiter_allreduce() + assert aiter_ar is not None, "aiter allreduce must be initialized" + + total_bytes = input_.numel() * input_.element_size() + hidden_dim = input_.shape[-1] + token_num = input_.shape[0] + hidden_ok = hidden_dim in (512, 1024, 2048, 4096, 7168) + token_ok = token_num <= 80 + world_size = aiter_ar.world_size + full_nvlink = aiter_ar.fully_connected + + if world_size == 2: + size_ok = True + elif full_nvlink and world_size <= 4: + size_ok = total_bytes < 256 * 1024 + elif full_nvlink and world_size <= 8: + size_ok = total_bytes < 128 * 1024 + else: + size_ok = False + + use_1stage = hidden_ok and token_ok and size_ok + + result = aiter_ar.fused_ar_rms( + input_, + residual, + w=weight, + eps=epsilon, + registered=torch.cuda.is_current_stream_capturing(), + use_1stage=use_1stage, + ) + assert result is not None + return result[0], result[1] + + +def _rocm_aiter_fused_allreduce_rmsnorm_fake( + input_: torch.Tensor, + residual: torch.Tensor, + weight: torch.Tensor, + epsilon: float, +) -> tuple[torch.Tensor, torch.Tensor]: + return torch.empty_like(input_), torch.empty_like(residual) + + def _rocm_aiter_per_tensor_quant_fake( x: torch.Tensor, quant_dtype: torch.dtype, @@ -747,7 +820,7 @@ def _rocm_aiter_per_token_quant_impl( assert quant_dtype in [torch.int8, FP8_DTYPE] out_shape = x.shape - out = torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device) + out = torch.empty(x.shape, dtype=quant_dtype, device=x.device) if scale is None: scale = torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device) dynamic_per_token_scaled_quant( @@ -767,7 +840,7 @@ def _rocm_aiter_per_token_quant_fake( ) -> tuple[torch.Tensor, torch.Tensor]: out_shape = x.shape return ( - torch.empty(x.shape, dtype=FP8_DTYPE, device=x.device), + torch.empty(x.shape, dtype=quant_dtype, device=x.device), torch.empty((*out_shape[:-1], 1), dtype=torch.float32, device=x.device), ) @@ -1157,6 +1230,9 @@ class rocm_aiter_ops: # TODO: Consolidate under _LINEAR_ENABLED _TRITON_UNQUANT_GEMM = envs.APHRODITE_ROCM_USE_AITER_TRITON_GEMM + _ALL_REDUCE_MAX_SIZE: int = 8192 * 1024 * 8 * 2 + _CUSTOM_ALL_REDUCE: AiterCustomAllreduceProto | None = None + @classmethod def refresh_env_variables(cls): """ @@ -1324,6 +1400,40 @@ def is_triton_rotary_embed_enabled(cls) -> bool: def is_triton_gemm_enabled(cls) -> bool: return cls._AITER_ENABLED and cls._TRITON_UNQUANT_GEMM + @classmethod + @if_aiter_supported + def is_tgemm_enabled(cls) -> bool: + from aphrodite.platforms.rocm import on_gfx950 + + return cls.is_linear_enabled() and on_gfx950() + + @classmethod + def initialize_aiter_allreduce(cls, group: ProcessGroup, device: torch.device) -> None: + try: + from aiter.dist.device_communicators.custom_all_reduce import ( + CustomAllreduce as AiterCustomAllreduce, + ) + + cls._CUSTOM_ALL_REDUCE = AiterCustomAllreduce(group, device) + except Exception: + cls._CUSTOM_ALL_REDUCE = None + + @classmethod + def get_aiter_allreduce(cls) -> AiterCustomAllreduceProto | None: + return cls._CUSTOM_ALL_REDUCE + + @classmethod + def destroy_aiter_allreduce(cls) -> None: + if cls._CUSTOM_ALL_REDUCE is not None: + cls._CUSTOM_ALL_REDUCE.close() + cls._CUSTOM_ALL_REDUCE = None + + @classmethod + def get_aiter_allreduce_max_size(cls) -> int | None: + # effective max input size (based on upstream aiter version: v0.1.10.post3) + # https://github.com/ROCm/aiter/blob/6a0e7b26ccf33164785531212cc2ec2cde0b9243/aiter/dist/device_communicators/custom_all_reduce.py#L272-L273 + return int(cls._ALL_REDUCE_MAX_SIZE / 2) + @staticmethod @if_aiter_supported def register_ops_once() -> None: @@ -1514,6 +1624,12 @@ def register_ops_once() -> None: fake_impl=_triton_rotary_embedding_fake, ) + direct_register_custom_op( + op_name="rocm_aiter_fused_allreduce_rmsnorm", + op_func=_rocm_aiter_fused_allreduce_rmsnorm_impl, + fake_impl=_rocm_aiter_fused_allreduce_rmsnorm_fake, + ) + direct_register_custom_op( op_name="fused_mla_dual_rms_norm", op_func=_fused_mla_dual_rms_norm_impl, @@ -1567,6 +1683,10 @@ def get_triton_add_rmsnorm_pad_op() -> OpOverload: def get_triton_rotary_embedding_op() -> OpOverload: return torch.ops.aphrodite.rocm_aiter_triton_rotary_embedding.default + @staticmethod + def get_fused_allreduce_rmsnorm_op() -> OpOverload: + return torch.ops.aphrodite.rocm_aiter_fused_allreduce_rmsnorm.default + @staticmethod def get_fused_mla_dual_rms_norm_op() -> OpOverload: return torch.ops.aphrodite.fused_mla_dual_rms_norm.default diff --git a/aphrodite/_custom_ops.py b/aphrodite/_custom_ops.py index 982b26ed00..f7c201fa70 100644 --- a/aphrodite/_custom_ops.py +++ b/aphrodite/_custom_ops.py @@ -2632,21 +2632,6 @@ def moe_wna16_gemm( ) -def router_gemm_bf16_fp32(input: torch.Tensor, weight: torch.Tensor) -> torch.Tensor: - """bf16 x bf16 -> fp32 GEMM via cuBLAS. weight shape: (N, K).""" - return torch.ops._moe_C.router_gemm_bf16_fp32(input, weight) - - -if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "router_gemm_bf16_fp32"): - - @register_fake("_moe_C::router_gemm_bf16_fp32") - def router_gemm_bf16_fp32_fake( - input: torch.Tensor, - weight: torch.Tensor, - ) -> torch.Tensor: - return torch.empty(input.shape[0], weight.shape[0], dtype=torch.float32, device=input.device) - - def dsv3_router_gemm( hidden_states: torch.Tensor, router_weight: torch.Tensor, @@ -3552,6 +3537,9 @@ def cpu_attn_reshape_and_cache( value_cache: torch.Tensor, slot_mapping: torch.Tensor, isa: str, + k_scale: float = 1.0, + v_scale: float = 1.0, + kv_cache_dtype: str = "auto", ) -> None: torch.ops._C.cpu_attn_reshape_and_cache( key, @@ -3560,6 +3548,9 @@ def cpu_attn_reshape_and_cache( value_cache, slot_mapping, isa, + k_scale, + v_scale, + kv_cache_dtype, ) @@ -3578,6 +3569,9 @@ def cpu_attention_with_kv_cache( softcap: float, scheduler_metadata: torch.Tensor, s_aux: torch.Tensor | None, + k_scale: float = 1.0, + v_scale: float = 1.0, + kv_cache_dtype: str = "auto", ) -> None: torch.ops._C.cpu_attention_with_kv_cache( query, @@ -3595,6 +3589,9 @@ def cpu_attention_with_kv_cache( softcap, scheduler_metadata, s_aux, + k_scale, + v_scale, + kv_cache_dtype, ) diff --git a/aphrodite/compilation/backends.py b/aphrodite/compilation/backends.py index 69b77b5122..2d6a490c96 100644 --- a/aphrodite/compilation/backends.py +++ b/aphrodite/compilation/backends.py @@ -265,6 +265,7 @@ def compile( compilation_counter.num_backend_compilations += 1 compiled_graph = None + handle = None # try to load from the cache compiled_graph = self.load(graph, example_inputs, graph_index, compile_range) @@ -342,7 +343,7 @@ def autograd_cache_key(*args, **kwargs): ) except StopCompiling: assert cache_key is not None - return self.loaded_artifacts[cache_key] + compiled_graph = self.loaded_artifacts[cache_key] if cache_key is not None and compiled_graph is not None: self.loaded_artifacts[cache_key] = compiled_graph diff --git a/aphrodite/compilation/cuda_graph.py b/aphrodite/compilation/cuda_graph.py index a9e7d56912..317e312c4e 100644 --- a/aphrodite/compilation/cuda_graph.py +++ b/aphrodite/compilation/cuda_graph.py @@ -268,8 +268,13 @@ def __call__(self, *args: Any, **kwargs: Any) -> Any | None: # across layers will make the cudagraph capture very slow. # therefore, we only run gc for the first graph, # and disable gc for the rest of the graphs. - stack.enter_context(patch("gc.collect", lambda: None)) - stack.enter_context(patch("torch.accelerator.empty_cache", lambda: None)) + stack.enter_context(patch("gc.collect", lambda *args, **kwargs: None)) + stack.enter_context( + patch( + "torch.accelerator.empty_cache", + lambda *args, **kwargs: None, + ) + ) if self.graph_pool is not None: set_graph_pool_id(self.graph_pool) diff --git a/aphrodite/compilation/decorators.py b/aphrodite/compilation/decorators.py index 23edabddc2..3f0d0e3c02 100644 --- a/aphrodite/compilation/decorators.py +++ b/aphrodite/compilation/decorators.py @@ -32,6 +32,9 @@ from .monitor import monitor_profiling_run, monitor_torch_compile +# shape_id parameter was added to mark_unbacked in PyTorch 2.11.0 +_SUPPORTS_SHAPE_ID = is_torch_equal_or_newer("2.11.0") + if TYPE_CHECKING: # Only added on nightly/2.10 so wrap try: @@ -89,7 +92,7 @@ def support_torch_compile( @overload def support_torch_compile( *, - dynamic_arg_dims: dict[str, int | list[int]] | None, + dynamic_arg_dims: dict[str, int | list[int] | dict[int, str]] | None, ) -> Callable[[type[_T]], type[_T]]: ... @@ -103,7 +106,7 @@ def support_torch_compile( @overload def support_torch_compile( *, - dynamic_arg_dims: dict[str, int | list[int]] | None, + dynamic_arg_dims: dict[str, int | list[int] | dict[int, str]] | None, mark_unbacked_dims: dict[str, int | list[int]] | None, ) -> Callable[[type[_T]], type[_T]]: ... @@ -115,11 +118,10 @@ def support_torch_compile(cls: type[_T]) -> type[_T]: ... def support_torch_compile( cls: type[_T] | None = None, *, - dynamic_arg_dims: dict[str, int | list[int]] | None = None, + dynamic_arg_dims: dict[str, int | list[int] | dict[int, str]] | None = None, mark_unbacked_dims: dict[str, int | list[int]] | None = None, enable_if: Callable[[AphroditeConfig], bool] | None = None, is_encoder: bool = False, - shape_invariants: Callable[..., None] = lambda *args, **kwargs: None, ) -> Callable[[type[_T]], type[_T]] | type[_T]: """ A decorator to add support for compiling the forward method of a class. @@ -141,8 +143,12 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ... ``` `dynamic_arg_dims` is a dictionary that maps argument names to the dynamic - dimensions of the argument. The dynamic dimensions can be either a single - integer or a list of integers. + dimensions of the argument. The value can be: + - int: a single dimension index (e.g., 0) + - list[int]: multiple dimension indices (e.g., [0, 1]) + - dict[int, str]: dimension to shape_id mapping for shape relations + (e.g., {0: "b"}). Dimensions with the same shape_id share the same + unbacked symbol. if `dynamic_arg_dims` is `None`, it is inferred from the type annotation of the `forward` method, based on the following default rules: @@ -189,7 +195,7 @@ def forward(self, x: torch.Tensor, y: Optional[torch.Tensor]): ... torch._check(input_ids.size()[0] == inputs_embeds.size()[0]) This enforces constraints on the symbolic shapes without hardcoding specific values. It is needed for some models to avoid data dependent - errors. + errors and maximize perf when unbacked shapes are used. """ def cls_decorator_helper(cls: type[_T]) -> type[_T]: @@ -233,7 +239,6 @@ def cls_decorator_helper(cls: type[_T]) -> type[_T]: mark_unbacked_dims, enable_if, is_encoder, - shape_invariants, ) if cls is not None: @@ -314,15 +319,13 @@ def _try_load_aot_compiled_fn( def _support_torch_compile( cls: type[_T], - dynamic_arg_dims: dict[str, int | list[int]], + dynamic_arg_dims: dict[str, int | list[int] | dict[int, str]], mark_unbacked_dims: dict[str, int | list[int]] | None = None, enable_if: Callable[[AphroditeConfig], bool] | None = None, is_encoder: bool = False, - shape_invariants: Callable[..., None] = lambda *args, **kwargs: None, ) -> type[_T]: - """ - A decorator to add support for compiling the forward method of a class. - """ + """Internal implementation of support_torch_compile decorator.""" + if TorchCompileWithNoGuardsWrapper in cls.__bases__: # support decorating multiple times return cls @@ -381,7 +384,8 @@ def __init__( if self.do_not_compile: return - self._check_shape_invariants = shape_invariants + self._dynamic_arg_dims = dynamic_arg_dims + self.was_aot_compile_fn_loaded_from_disk = False compilation_counter.num_models_seen += 1 self.compiled = False @@ -396,43 +400,70 @@ def __init__( cls.__init__ = __init__ def _mark_dynamic_inputs(mod: type[_T], ds_type: DynamicShapesType, *args: Any, **kwargs: Any) -> None: - def mark_dynamic(arg: torch.Tensor, dims: list[int]) -> None: + def mark_dynamic(arg: torch.Tensor, dim_shape_pairs: list[tuple[int, str | None]]) -> None: if ds_type == DynamicShapesType.UNBACKED: if is_torch_equal_or_newer("2.10.0"): - for dim in dims: - torch._dynamo.decorators.mark_unbacked(arg, dim, hint_override=arg.size()[dim]) + for dim, shape_id in dim_shape_pairs: + if shape_id is not None: + if not _SUPPORTS_SHAPE_ID: + raise RuntimeError(f"shape_id='{shape_id}' requires PyTorch >= 2.11.0") + torch._dynamo.decorators.mark_unbacked( + arg, + dim, + hint_override=arg.size()[dim], + shape_id=shape_id, + ) + else: + torch._dynamo.decorators.mark_unbacked( + arg, + dim, + hint_override=arg.size()[dim], + ) else: + # For older versions, we can't use hint_override or shape_id + dims = [dim for dim, _ in dim_shape_pairs] torch._dynamo.decorators.mark_unbacked(arg, dims) else: + dims = [dim for dim, _ in dim_shape_pairs] torch._dynamo.mark_dynamic(arg, dims) sig = inspect.signature(mod.__class__.forward) # type: ignore[attr-defined] bound_args = sig.bind(mod, *args, **kwargs) bound_args.apply_defaults() - for k, dims in dynamic_arg_dims.items(): + + # Normalize dynamic_arg_dims to dict[str, dict[int, str | None]] + normalized_dims: dict[str, dict[int, str | None]] = {} + for k, v in dynamic_arg_dims.items(): + if isinstance(v, dict): + normalized_dims[k] = {dim: shape_id for dim, shape_id in v.items()} + elif isinstance(v, int): + normalized_dims[k] = {v: None} + else: + normalized_dims[k] = {d: None for d in v} + + for k, dim_to_shape_id in normalized_dims.items(): arg = bound_args.arguments.get(k) if arg is not None: - dims = [dims] if isinstance(dims, int) else dims + dims = list(dim_to_shape_id.keys()) + if isinstance(arg, torch.Tensor): - # In case dims is specified with negative indexing - dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] - mark_dynamic(arg, dims) + dim_shape_pairs = [(arg.ndim + d if d < 0 else d, dim_to_shape_id.get(d)) for d in dims] + mark_dynamic(arg, dim_shape_pairs) elif isinstance(arg, IntermediateTensors): for tensor in arg.tensors.values(): - # In case dims is specified with negative indexing - dims = [tensor.ndim + dim if dim < 0 else dim for dim in dims] - mark_dynamic(tensor, dims) + dim_shape_pairs = [(tensor.ndim + d if d < 0 else d, dim_to_shape_id.get(d)) for d in dims] + mark_dynamic(tensor, dim_shape_pairs) else: raise ValueError(f"Unsupported dynamic dimensions {dims} for argument {k} with type {type(arg)}.") + if mark_unbacked_dims: - for k, dims in mark_unbacked_dims.items(): + for k, dims_val in mark_unbacked_dims.items(): arg = bound_args.arguments.get(k) if arg is not None: - dims = [dims] if isinstance(dims, int) else dims + dims = [dims_val] if isinstance(dims_val, int) else list(dims_val) if isinstance(arg, torch.Tensor): - # In case dims is specified with negative indexing - dims = [arg.ndim + dim if dim < 0 else dim for dim in dims] + dims = [arg.ndim + d if d < 0 else d for d in dims] if is_torch_equal_or_newer("2.10.0"): for dim in dims: torch._dynamo.decorators.mark_unbacked(arg, dim, hint_override=arg.size()[dim]) diff --git a/aphrodite/compilation/passes/fusion/act_quant_fusion.py b/aphrodite/compilation/passes/fusion/act_quant_fusion.py index 247efc2296..56594d3a4f 100644 --- a/aphrodite/compilation/passes/fusion/act_quant_fusion.py +++ b/aphrodite/compilation/passes/fusion/act_quant_fusion.py @@ -183,6 +183,7 @@ def __init__( is_scale_transposed: bool = False, is_e8m0: bool = False, is_tma_aligned: bool = False, + match_aiter: bool = False, ) -> None: super().__init__(quant_key) self.quant_matcher = MatcherQuantFP8( diff --git a/aphrodite/compilation/passes/fusion/allreduce_rms_fusion.py b/aphrodite/compilation/passes/fusion/allreduce_rms_fusion.py index c8c74412d4..c70b14d8b2 100644 --- a/aphrodite/compilation/passes/fusion/allreduce_rms_fusion.py +++ b/aphrodite/compilation/passes/fusion/allreduce_rms_fusion.py @@ -12,12 +12,14 @@ from torch._inductor.pattern_matcher import PatternMatcherPass import aphrodite.ir.ops +from aphrodite._aiter_ops import rocm_aiter_ops from aphrodite.compilation.passes.fusion.rms_quant_fusion import ( _rms_input_weight_dtype_match, ) from aphrodite.config import AphroditeConfig from aphrodite.config.utils import Range from aphrodite.distributed import get_tp_group, tensor_model_parallel_all_reduce +from aphrodite.distributed.device_communicators.custom_all_reduce import CustomAllreduce from aphrodite.distributed.parallel_state import ( get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, @@ -31,7 +33,12 @@ direct_register_custom_op, ) -from ..aphrodite_inductor_pass import AphroditeInductorPass, AphroditePatternMatcherPass +from ..aphrodite_inductor_pass import ( + AphroditeFusionPatternMatcherPass, + AphroditeInductorPass, + AphroditePatternMatcherPass, + AphroditePatternReplacement, +) from ..inductor_pass import enable_fake_mode from .matcher_utils import MatcherFusedAddRMSNorm, MatcherQuantFP8 @@ -845,3 +852,192 @@ def __del__(self) -> None: return with contextlib.suppress(Exception): destroy_fi_ar_workspace() + + +# TODO: make BasePattern to inherit from AphroditePatternReplacement +class AiterAllreduceFusedRMSNormPattern(BasePattern, AphroditePatternReplacement): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str | None, + use_aiter_rmsnorm: bool = True, + ) -> None: + super().__init__(dtype, device) + self.dtype = dtype + self.epsilon = epsilon + self.FUSED_AR_RMSNORM_OP = rocm_aiter_ops.get_fused_allreduce_rmsnorm_op() + + def get_inputs(self) -> list[torch.Tensor]: + return [self.empty(5, 16), self.empty(16)] + + @property + def pattern(self): + def _pattern(input: torch.Tensor, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + allreduce_output = tensor_model_parallel_all_reduce(input) + rms = aphrodite.ir.ops.rms_norm(allreduce_output, weight, self.epsilon) + + return rms, allreduce_output + + return _pattern + + @property + def replacement(self): + def _replacement(input: torch.Tensor, weight: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + residual = torch.empty_like(input) + allreduce = self.FUSED_AR_RMSNORM_OP( + input_=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) + return allreduce[0], allreduce[1] + + return _replacement + + +class AiterAllreduceFusedAddRMSNormPattern(BasePattern, AphroditePatternReplacement): + def __init__( + self, + epsilon: float, + dtype: torch.dtype, + device: str | None, + use_aiter_rmsnorm: bool = True, + ) -> None: + super().__init__(dtype, device) + self.epsilon = epsilon + self.dtype = dtype + self.rmsnorm_matcher = MatcherFusedAddRMSNorm(epsilon, match_rocm_aiter=use_aiter_rmsnorm) + self.FUSED_AR_RMSNORM_OP = rocm_aiter_ops.get_fused_allreduce_rmsnorm_op() + + def get_inputs(self) -> list[torch.Tensor]: + input, residual, weight = self.rmsnorm_matcher.inputs() + + return [residual, input.to(self.dtype), weight] + + @property + def pattern(self): + def _pattern( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce_output = tensor_model_parallel_all_reduce(input) + rms, residual = self.rmsnorm_matcher(allreduce_output, weight, residual) + + return rms, residual + + return _pattern + + @property + def replacement(self): + def _replacement( + residual: torch.Tensor, input: torch.Tensor, weight: torch.Tensor + ) -> tuple[torch.Tensor, torch.Tensor]: + allreduce = self.FUSED_AR_RMSNORM_OP( + input_=input, + residual=residual, + weight=weight, + epsilon=self.epsilon, + ) + return allreduce[0], allreduce[1] + + return _replacement + + +class RocmAiterAllReduceFusionPass(AphroditeFusionPatternMatcherPass): + def __init__(self, config: AphroditeConfig) -> None: + super().__init__(config, "rocm_aiter_allreduce_fusion_pass") + self.disabled = True + self.tp_size = get_tensor_model_parallel_world_size() + if self.tp_size <= 1: + logger.warning_once("AllReduce fusion pass is disabled for tp_size <= 1.") + return + + if config.model_config is None: + logger.warning_once("AllReduce fusion pass is disabled for missing model_config.") + return + + device_comm = get_tp_group().device_communicator + if device_comm is None: + logger.warning_once("Device communicator is required.") + return + + ca_comm = getattr(device_comm, "ca_comm", None) + if ca_comm is None: + logger.warning_once("Custom Allreduce is required.") + return + self.ca_comm = ca_comm + + assert isinstance(ca_comm, CustomAllreduce) + + group = get_tp_group().cpu_group + rocm_aiter_ops.initialize_aiter_allreduce(group, self.device) + hidden_dim = config.model_config.get_hidden_size() + element_size = torch.tensor([], dtype=self.model_dtype).element_size() + max_size = rocm_aiter_ops.get_aiter_allreduce_max_size() + if max_size is None: + logger.warning("AITER allreduce fusion must be initialized") + return + + # Aiter's fused_allreduce_rmsnorm kernel dispatches on hidden_dim. + # Before aiter v0.1.12 the launcher was template-specialized on HIDDEN_DIM + # and silently no-op'd for sizes outside {512, 1024, 2048, 4096}. From v0.1.12 + # hidden_dim is a runtime argument. Detect the older API via the missing + # `_pool` attribute and skip fusion for unsupported sizes. + # Ref (old kernel): https://github.com/ROCm/aiter/blob/6a0e7b26ccf33164785531212cc2ec2cde0b9243/csrc/include/custom_all_reduce.cuh#L2590 + aiter_ar = rocm_aiter_ops.get_aiter_allreduce() + _AITER_OLD_FUSED_AR_RMS_HIDDEN = (512, 1024, 2048, 4096) + if aiter_ar is not None and not hasattr(aiter_ar, "_pool") and hidden_dim not in _AITER_OLD_FUSED_AR_RMS_HIDDEN: + logger.warning_once( + "AITER allreduce-rmsnorm fusion disabled: aiter<0.1.12 " + "only supports hidden_dim in %s; got %d. Upgrade aiter to " + ">=0.1.12 to enable fusion for this model.", + _AITER_OLD_FUSED_AR_RMS_HIDDEN, + hidden_dim, + ) + # Tear down aiter's custom-allreduce so its IPC handles don't + # race with aphrodite's ca_comm on the unfused fallback path. + with contextlib.suppress(Exception): + rocm_aiter_ops.destroy_aiter_allreduce() + return + + max_token_num = max_size // (hidden_dim * element_size) + self.max_token_num = min( + max_token_num, + config.scheduler_config.max_num_batched_tokens, + ) + + for epsilon in [1e-5, 1e-6]: + self.register( + AiterAllreduceFusedRMSNormPattern( + epsilon, + self.model_dtype, + self.device, + ) + ) + self.register( + AiterAllreduceFusedAddRMSNormPattern( + epsilon, + self.model_dtype, + self.device, + ) + ) + + # WARNING: This is a hack to clear the pattern matcher cache + # and allow multiple values of epsilon. + torch._inductor.pattern_matcher._seen_patterns.clear() + + self.disabled = False + + self.dump_patterns(config, self.pm_pass) + + def is_applicable_for_range(self, compile_range: Range) -> bool: + if self.disabled: + logger.warning_once("AllReduce fusion pass is disabled.") + return False + return bool(compile_range.end <= self.max_token_num) + + def __del__(self) -> None: + if getattr(self, "disabled", True): + return + with contextlib.suppress(Exception): + rocm_aiter_ops.destroy_aiter_allreduce() diff --git a/aphrodite/compilation/passes/fusion/collective_fusion.py b/aphrodite/compilation/passes/fusion/collective_fusion.py index cb5a6411a7..8eff40bfb5 100644 --- a/aphrodite/compilation/passes/fusion/collective_fusion.py +++ b/aphrodite/compilation/passes/fusion/collective_fusion.py @@ -1,8 +1,12 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +from collections.abc import Callable +from contextlib import suppress + import torch import torch._inductor.pattern_matcher as pm +import torch.distributed.distributed_c10d as c10d import torch.fx as fx from torch._inductor.pattern_matcher import PatternMatcherPass from torch.distributed._symmetric_memory import enable_symm_mem_for_group @@ -15,8 +19,14 @@ ) from aphrodite.logger import init_logger from aphrodite.platforms import current_platform +from aphrodite.utils.torch_utils import direct_register_custom_op -from ..aphrodite_inductor_pass import AphroditeInductorPass, AphroditePatternMatcherPass +from ..aphrodite_inductor_pass import ( + AphroditeFusionPatternMatcherPass, + AphroditeInductorPass, + AphroditePatternMatcherPass, + AphroditePatternReplacement, +) from ..inductor_pass import enable_fake_mode FP8_DTYPE = current_platform.fp8_dtype() @@ -24,6 +34,172 @@ logger = init_logger(__name__) +def _flashinfer_scaled_mm_out( + A: torch.Tensor, + B: torch.Tensor, + *, + scale_a: torch.Tensor, + scale_b: torch.Tensor, + out: torch.Tensor, + bias: torch.Tensor | None = None, + scale_result: torch.Tensor | None = None, + out_dtype: torch.dtype | None = None, + use_fast_accum: bool = False, +) -> None: + # Import lazily to avoid a circular import during module initialization + # when docs or other tooling import the pass without FlashInfer. + from aphrodite.utils.flashinfer import flashinfer_scaled_fp8_mm_out + + assert bias is None, "FlashInfer symm_mem adapter does not support bias" + assert scale_result is None, "FlashInfer symm_mem adapter does not support result scaling" + assert not use_fast_accum, "FlashInfer symm_mem adapter does not support use_fast_accum" + assert A.ndim == 2 and B.ndim == 2 and out.ndim == 2, "FlashInfer symm_mem adapter expects 2D inputs and output" + assert scale_a.numel() == 1 and scale_b.numel() == 1, ( + "FlashInfer symm_mem adapter only supports tensor-wise FP8 scales" + ) + + flashinfer_scaled_fp8_mm_out( + A, + B, + scale_a, + scale_b, + out=out, + out_dtype=out_dtype or out.dtype, + ) + + +def fused_flashinfer_scaled_matmul_reduce_scatter_fake( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: str, + output_shape: list[int], + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + world_size = c10d._resolve_process_group(group_name).size() + result_shape = list(output_shape) + result_shape[orig_scatter_dim] //= world_size + return torch.empty( + result_shape, + dtype=out_dtype or torch.bfloat16, + device=A.device, + ) + + +def fused_flashinfer_scaled_matmul_reduce_scatter( + A: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + reduce_op: str, + orig_scatter_dim: int, + scatter_dim_after_maybe_reshape: int, + group_name: str, + output_shape: list[int], + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + assert orig_scatter_dim == 0 and scatter_dim_after_maybe_reshape == 0, ( + "FlashInfer symm_mem adapter currently only supports scatter_dim=0" + ) + world_size = c10d._resolve_process_group(group_name).size() + assert A.ndim == 2 and B.ndim == 2, "FlashInfer symm_mem adapter expects 2D inputs" + assert A.is_contiguous(), "FlashInfer symm_mem adapter expects contiguous A" + assert A_scale.numel() == 1 and B_scale.numel() == 1, ( + "FlashInfer symm_mem adapter only supports tensor-wise FP8 scales" + ) + assert A.shape[0] % world_size == 0, "FlashInfer symm_mem adapter expects M divisible by world size" + + kwargs = { + "scale_b": B_scale, + "bias": None, + "scale_result": None, + "out_dtype": out_dtype, + "use_fast_accum": False, + } + return torch.distributed._symmetric_memory._fused_scaled_matmul_reduce_scatter_impl( + mm_out_op=_flashinfer_scaled_mm_out, + A=A, + B=B, + A_scale=A_scale, + kwargs=kwargs, + out_dtype=out_dtype, + reduce_op=reduce_op, + orig_scatter_dim=orig_scatter_dim, + scatter_dim_after_maybe_reshape=scatter_dim_after_maybe_reshape, + group_name=group_name, + output_shape=output_shape, + ) + + +def fused_all_gather_flashinfer_scaled_matmul_fake( + A_shard: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + gather_dim: int, + group_name: str, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + world_size = c10d._resolve_process_group(group_name).size() + output_shape = list(A_shard.shape) + output_shape[gather_dim] *= world_size + output_shape[-1] = B.shape[1] + return torch.empty( + output_shape, + dtype=out_dtype or torch.bfloat16, + device=A_shard.device, + ) + + +def fused_all_gather_flashinfer_scaled_matmul( + A_shard: torch.Tensor, + B: torch.Tensor, + A_scale: torch.Tensor, + B_scale: torch.Tensor, + gather_dim: int, + group_name: str, + out_dtype: torch.dtype | None = None, +) -> torch.Tensor: + assert gather_dim == 0, "FlashInfer symm_mem adapter currently only supports gather_dim=0" + _, outputs = torch.distributed._symmetric_memory._fused_all_gather_matmul_impl( + mm_out_op=_flashinfer_scaled_mm_out, + A_shard=A_shard, + Bs=[B], + A_scale=A_scale, + kwargs_list=[ + { + "scale_b": B_scale, + "bias": None, + "scale_result": None, + "out_dtype": out_dtype, + "use_fast_accum": False, + } + ], + out_dtypes=[out_dtype], + gather_dim=gather_dim, + group_name=group_name, + return_A=False, + ) + return outputs[0] + + +direct_register_custom_op( + op_name="fused_flashinfer_scaled_matmul_reduce_scatter", + op_func=fused_flashinfer_scaled_matmul_reduce_scatter, + fake_impl=fused_flashinfer_scaled_matmul_reduce_scatter_fake, +) + +direct_register_custom_op( + op_name="fused_all_gather_flashinfer_scaled_matmul", + op_func=fused_all_gather_flashinfer_scaled_matmul, + fake_impl=fused_all_gather_flashinfer_scaled_matmul_fake, +) + + class BasePattern: def __init__(self, dtype: torch.dtype, device: str | None) -> None: self.dtype = dtype @@ -343,29 +519,145 @@ def replacement( pm.register_replacement(pattern, replacement, self.get_inputs(), pm.fwd_only, pm_pass) -class AsyncTPPass(AphroditePatternMatcherPass): +class FlashInferBMMFP8ReduceScatterPattern(BasePattern, AphroditePatternReplacement[..., torch.Tensor]): + def get_inputs(self) -> list[torch.Tensor]: + a_2d = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE) + b_2d = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE).contiguous().transpose(0, 1) + a_scale = torch.empty([1], device=self.device, dtype=torch.float32) + b_scale = torch.empty([1], device=self.device, dtype=torch.float32) + return [a_2d, b_2d, a_scale, b_scale] + + @property + def pattern(self) -> Callable[..., torch.Tensor]: + def _pattern( + a_2d: torch.Tensor, + b_2d: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + ) -> torch.Tensor: + bmm = torch.ops.aphrodite.bmm_fp8.default( + torch.ops.aten.unsqueeze.default(a_2d, 0), + torch.ops.aten.unsqueeze.default(b_2d, 0), + a_scale, + b_scale, + self.dtype, + "auto", + ) + output = torch.ops.aten.reshape.default(bmm, list(bmm.shape[1:])) + return torch.ops.aphrodite.reduce_scatter.default( + output, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name, + ) + + return _pattern + + @property + def replacement(self) -> Callable[..., torch.Tensor]: + def _replacement( + a_2d: torch.Tensor, + b_2d: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + ) -> torch.Tensor: + return torch.ops.aphrodite.fused_flashinfer_scaled_matmul_reduce_scatter.default( + a_2d, + b_2d, + a_scale, + b_scale, + "sum", + 0, + 0, + self.tp.device_group.group_name, + [a_2d.shape[0], b_2d.shape[1]], + self.dtype, + ) + + return _replacement + + +class FlashInferAllGatherBMMFP8Pattern(BasePattern, AphroditePatternReplacement[..., torch.Tensor]): + def get_inputs(self) -> list[torch.Tensor]: + a_shard_2d = torch.empty([8, 16], device=self.device, dtype=FP8_DTYPE) + b_2d = torch.empty([16, 16], device=self.device, dtype=FP8_DTYPE).contiguous().transpose(0, 1) + a_scale = torch.empty([1], device=self.device, dtype=torch.float32) + b_scale = torch.empty([1], device=self.device, dtype=torch.float32) + return [a_shard_2d, b_2d, a_scale, b_scale] + + @property + def pattern(self) -> Callable[..., torch.Tensor]: + def _pattern( + a_shard_2d: torch.Tensor, + b_2d: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + ) -> torch.Tensor: + all_gather = torch.ops.aphrodite.all_gather.default( + a_shard_2d, + dim=0, + world_size=self.tp_size, + group_name=self.tp.unique_name, + ) + return torch.ops.aphrodite.bmm_fp8.default( + torch.ops.aten.unsqueeze.default(all_gather, 0), + torch.ops.aten.unsqueeze.default(b_2d, 0), + a_scale, + b_scale, + self.dtype, + "auto", + ) + + return _pattern + + @property + def replacement(self) -> Callable[..., torch.Tensor]: + def _replacement( + a_shard_2d: torch.Tensor, + b_2d: torch.Tensor, + a_scale: torch.Tensor, + b_scale: torch.Tensor, + ) -> torch.Tensor: + fused = torch.ops.aphrodite.fused_all_gather_flashinfer_scaled_matmul.default( + a_shard_2d, + b_2d, + a_scale, + b_scale, + 0, + self.tp.device_group.group_name, + self.dtype, + ) + return torch.ops.aten.unsqueeze.default(fused, 0) + + return _replacement + + +class AsyncTPPass(AphroditeFusionPatternMatcherPass): @enable_fake_mode def __init__(self, config: AphroditeConfig) -> None: - super().__init__(config) + super().__init__(config, pass_name="async_tp_pass") - # Enable symmetric memory for the TP process group enable_symm_mem_for_group(get_tp_group().device_group.group_name) - self.patterns: PatternMatcherPass = PatternMatcherPass(pass_name="async_tp_pass") - GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns) + GEMMReduceScatterPattern(self.model_dtype, self.device).register(self.pm_pass) - AllGatherGEMMPattern(self.model_dtype, self.device).register(self.patterns) + AllGatherGEMMPattern(self.model_dtype, self.device).register(self.pm_pass) # These fusions are enabled only for bfloat16 models because # `scaled_mm` or `cutlass_scaled_mm` with per-token (row-wise) scaling # only supports bfloat16 as the output dtype. if self.model_dtype == torch.bfloat16: - ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns) - AllGatherScaledMMPattern(self.model_dtype, self.device).register(self.patterns) + ScaledMMReduceScatterPattern(self.model_dtype, self.device).register(self.pm_pass) + AllGatherScaledMMPattern(self.model_dtype, self.device).register(self.pm_pass) - CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(self.patterns) - AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(self.patterns) + CutlassScaledMMReduceScatterPattern(self.model_dtype, self.device).register(self.pm_pass) + AllGatherCutlassScaledMMPattern(self.model_dtype, self.device).register(self.pm_pass) + with suppress(ImportError): + import aphrodite.utils.flashinfer # noqa: F401 + if hasattr(torch.ops.aphrodite, "bmm_fp8"): + self.register(FlashInferAllGatherBMMFP8Pattern(self.model_dtype, self.device)) + self.register(FlashInferBMMFP8ReduceScatterPattern(self.model_dtype, self.device)) - self.dump_patterns(config, self.patterns) + self.dump_patterns(config, self.pm_pass) def is_applicable_for_range(self, compile_range: Range) -> bool: # This pass is applied on top of the sequence parallelism pass, @@ -377,5 +669,6 @@ def is_applicable_for_range(self, compile_range: Range) -> bool: @AphroditeInductorPass.time_and_log def __call__(self, graph: fx.Graph) -> None: - self.matched_count = self.patterns.apply(graph) + self.matched_count = self.pm_pass.apply(graph) + AphroditePatternMatcherPass.match_table[self.pass_name] += self.matched_count logger.debug("Replaced %s patterns", self.matched_count) diff --git a/aphrodite/compilation/passes/fusion/sequence_parallelism.py b/aphrodite/compilation/passes/fusion/sequence_parallelism.py index 669f91020b..6b908a12ec 100644 --- a/aphrodite/compilation/passes/fusion/sequence_parallelism.py +++ b/aphrodite/compilation/passes/fusion/sequence_parallelism.py @@ -31,6 +31,7 @@ # Only apply sequence parallelism for models with hidden_size >= threshold SP_MIN_HIDDEN_SIZE: dict[int, int] = { 90: 8192, # H100: only for models with hidden_size >= 8192 + 100: 8192, # Blackwell family: only for models with hidden_size >= 8192 } # Min size per GPU per device capability for sequence parallelism @@ -38,6 +39,8 @@ # This ensures the threshold scales appropriately with tensor parallelism SP_MIN_PER_GPU_SIZE_MB: dict[int, float] = { 90: 8, # 8MB per GPU for H100 + # Use a more conservative threshold on Blackwell so TP8 starts later. + 100: 32, } @@ -67,7 +70,12 @@ def get_sequence_parallelism_threshold( capability = current_platform.get_device_capability() if capability is None: return None - device_capability = capability.to_int() + + # Collapse Blackwell variants (sm100/sm103/...) into one policy bucket. + if current_platform.is_device_capability_family(100): + device_capability = 100 + else: + device_capability = capability.to_int() # Check if device has configured thresholds min_hidden_size = SP_MIN_HIDDEN_SIZE.get(device_capability) diff --git a/aphrodite/compilation/passes/pass_manager.py b/aphrodite/compilation/passes/pass_manager.py index da4900e48b..b69128386e 100644 --- a/aphrodite/compilation/passes/pass_manager.py +++ b/aphrodite/compilation/passes/pass_manager.py @@ -18,6 +18,9 @@ from .ir.lowering_pass import AphroditeIRLoweringPass if rocm_aiter_ops.is_enabled(): + from .fusion.allreduce_rms_fusion import ( + RocmAiterAllReduceFusionPass, + ) from .fusion.rocm_aiter_fusion import ( MLADualRMSNormFusionPass, RocmAiterRMSNormQuantFusionPass, @@ -137,7 +140,10 @@ def configure(self, config: AphroditeConfig) -> None: self.passes += [AsyncTPPass(config)] if self.pass_config.fuse_allreduce_rms: - self.passes += [AllReduceFusionPass(config)] + if rocm_aiter_ops.is_enabled(): + self.passes += [RocmAiterAllReduceFusionPass(config)] + else: + self.passes += [AllReduceFusionPass(config)] if self.pass_config.fuse_minimax_qk_norm: self.passes += [MiniMaxQKNormPass(config)] diff --git a/aphrodite/compilation/wrapper.py b/aphrodite/compilation/wrapper.py index 87810e33da..8deda0fa8f 100644 --- a/aphrodite/compilation/wrapper.py +++ b/aphrodite/compilation/wrapper.py @@ -53,12 +53,6 @@ class TorchCompileWithNoGuardsWrapper: since we drop all guards. """ - def check_invariants_and_forward(self, *args: Any, **kwargs: Any) -> Any: - assert hasattr(self, "_check_shape_invariants") - self._check_shape_invariants(*args, **kwargs) - - return self.forward(*args, **kwargs) - def _call_with_optional_nvtx_range(self, callable_fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> Any: if self.layerwise_nvtx_tracing_enabled: args_list = list(args) @@ -109,6 +103,8 @@ def __init__( "compilation_config.dynamic_shapes_config.evaluate_guards requires APHRODITE_USE_BYTECODE_HOOK=0. " ) + assert ds_type != DynamicShapesType.UNBACKED, "UNBACKED dynamic shapes do not add guards" + options["guard_filter_fn"] = lambda x: [entry.guard_type == "SHAPE_ENV" for entry in x] else: if hasattr(torch.compiler, "skip_all_guards_unsafe"): @@ -121,19 +117,6 @@ def __init__( compiled_ptr: Any = self.forward # Validate that unbacked dynamic shapes require APHRODITE_USE_BYTECODE_HOOK=False - if ds_type == DynamicShapesType.UNBACKED: - # reason is that bytecode does torch._dynamo.eval_frame. - # remove_from_cache(self.original_code_object()) to force a new - # re-compilation. And if we use - # compiled_ptr = self.check_invariants_and_forward - # it will reset all entries. - assert not envs.APHRODITE_USE_BYTECODE_HOOK, ( - "UNBACKED dynamic shapes requires APHRODITE_USE_BYTECODE_HOOK=0. " - ) - assert not self.evaluate_guards, "UNBACKED dynamic shapes do not add guards" - - compiled_ptr = self.check_invariants_and_forward - # Apply the constrain_to_fx_strides patch before first compilation. # This covers STOCK_TORCH_COMPILE and DYNAMO_ONCE paths. The APHRODITE # compile paths call this from their own compile() methods too. diff --git a/aphrodite/config/aphrodite.py b/aphrodite/config/aphrodite.py index 306a73b4d5..b8b6b165f2 100644 --- a/aphrodite/config/aphrodite.py +++ b/aphrodite/config/aphrodite.py @@ -123,6 +123,15 @@ def enable_allreduce_rms_fusion(cfg: "AphroditeConfig") -> bool: from aphrodite.platforms import current_platform from aphrodite.utils.flashinfer import has_flashinfer + if current_platform.is_rocm(): + from aphrodite._aiter_ops import rocm_aiter_ops + + return ( + rocm_aiter_ops.is_enabled() + and rocm_aiter_ops.is_rmsnorm_enabled() + and cfg.parallel_config.tensor_parallel_size > 1 + ) + return ( cfg.parallel_config.tensor_parallel_size > 1 and current_platform.is_cuda() @@ -1331,6 +1340,10 @@ def _set_cudagraph_sizes(self): cudagraph_capture_sizes = [1, 2, 4] + list(range(8, 256, 8)) + list( range(256, max_graph_size + 1, 16)) + `max_num_batched_tokens` is also appended to the list if it fits + within `max_cudagraph_capture_size`, so the max batch size is captured + even when off-stride. + In the end, `aphrodite_config.compilation_config.cudagraph_capture_sizes` will be the final sizes to capture cudagraph (in ascending order). @@ -1402,6 +1415,9 @@ def _set_cudagraph_sizes(self): if max_cudagraph_capture_size >= 256: # Step size 16 for larger batch sizes cudagraph_capture_sizes += list(range(256, max_cudagraph_capture_size + 1, 16)) + # ensure max_num_tokens is captured if within max capture size + if max_num_tokens <= max_cudagraph_capture_size and max_num_tokens not in cudagraph_capture_sizes: + cudagraph_capture_sizes.append(max_num_tokens) # de-duplicate and sort the sizes cudagraph_capture_sizes = sorted(set(cudagraph_capture_sizes)) @@ -1466,10 +1482,15 @@ def _set_compile_ranges(self): if compile_range_end is not None: computed_compile_ranges_endpoints.append(compile_range_end) - # Add the compile ranges for flashinfer + # Add the compile ranges for flashinfer/aiter. if compilation_config.pass_config.fuse_allreduce_rms: tp_size = self.parallel_config.tensor_parallel_size - max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) + from aphrodite._aiter_ops import rocm_aiter_ops + + if rocm_aiter_ops.is_enabled(): + max_size = rocm_aiter_ops.get_aiter_allreduce_max_size() + else: + max_size = compilation_config.pass_config.flashinfer_max_size(tp_size) if max_size is not None: assert isinstance(self.model_config.dtype, torch.dtype) max_token_num = max_size // (self.model_config.get_hidden_size() * self.model_config.dtype.itemsize) @@ -1718,6 +1739,18 @@ def validate_block_size(self) -> None: "in the middle of a mm input" ) + @model_validator(mode="after") + def validate_nvfp4_kv_cache_with_mla(self) -> "AphroditeConfig": + if self.model_config is None: + return self + if self.cache_config.cache_dtype == "nvfp4" and self.model_config.use_mla: + raise ValueError( + "nvfp4 KV cache is not supported with MLA (Multi-head Latent " + "Attention) backends. Please use a different --kv-cache-dtype " + "(e.g., 'fp8' or 'auto') for MLA models such as DeepSeek." + ) + return self + @model_validator(mode="after") def validate_mamba_block_size(self) -> "AphroditeConfig": if self.model_config is None: diff --git a/aphrodite/config/attention.py b/aphrodite/config/attention.py index d131e3bd9f..48bbe88056 100644 --- a/aphrodite/config/attention.py +++ b/aphrodite/config/attention.py @@ -6,8 +6,12 @@ from pydantic import field_validator from aphrodite.config.utils import config +from aphrodite.logger import init_logger +from aphrodite.v1.attention.backends.mla.prefill.registry import MLAPrefillBackendEnum from aphrodite.v1.attention.backends.registry import AttentionBackendEnum +logger = init_logger(__name__) + @config class AttentionConfig: @@ -33,7 +37,7 @@ class AttentionConfig: and buffers can be pre-allocated to avoid inflating the memory estimate.""" use_cudnn_prefill: bool = False - """Whether to use cudnn prefill.""" + """Deprecated: cuDNN prefill backend has been removed.""" use_trtllm_ragged_deepseek_prefill: bool = False """Whether to use TRTLLM ragged deepseek prefill.""" @@ -42,18 +46,27 @@ class AttentionConfig: """If set to True/False, use or don't use the TRTLLM attention backend in flashinfer. If None, auto-detect the attention backend in flashinfer.""" - disable_flashinfer_prefill: bool = True + disable_flashinfer_prefill: bool | None = None """Whether to disable flashinfer prefill.""" disable_flashinfer_q_quantization: bool = False """If set, when using fp8 kv, do not quantize Q to fp8.""" + mla_prefill_backend: MLAPrefillBackendEnum | None = None + """MLA prefill backend to use. If None, will be selected automatically. + Valid options: FLASH_ATTN (FA3/FA4), FLASHINFER, TRTLLM_RAGGED. + This option supersedes use_trtllm_ragged_deepseek_prefill + and disable_flashinfer_prefill which are deprecated.""" + use_prefill_query_quantization: bool = False """If set, quantize query for attention in prefill.""" use_fp4_indexer_cache: bool = False """If set, use fp4 indexer cache for dsv32 family model (not support yet)""" + use_non_causal: bool = False + """Whether to use non-causal (bidirectional) attention.""" + def compute_hash(self) -> str: """ Provide a hash that uniquely identifies all the configs @@ -81,3 +94,48 @@ def validate_backend_before(cls, value: Any) -> Any: return None return AttentionBackendEnum[value.upper()] return value + + @field_validator("mla_prefill_backend", mode="before") + @classmethod + def validate_mla_prefill_backend_before(cls, value: Any) -> Any: + """Enable parsing of the `mla_prefill_backend` enum type from string.""" + if isinstance(value, str): + return MLAPrefillBackendEnum[value.upper()] + return value + + def __post_init__(self) -> None: + self._migrate_deprecated_mla_prefill_flags() + + def _migrate_deprecated_mla_prefill_flags(self) -> None: + """Migrate deprecated MLA prefill flags to mla_prefill_backend.""" + # If the new option is already set, it takes precedence + if self.mla_prefill_backend is not None: + return + + # Check for deprecated flags and migrate them. + # Only the first flag encountered sets the backend. + if self.use_cudnn_prefill: + raise ValueError( + "The cuDNN MLA prefill backend has been removed. " + "Use --attention-config.mla_prefill_backend=FLASH_ATTN or " + "FLASHINFER or TRTLLM_RAGGED instead." + ) + + if self.use_trtllm_ragged_deepseek_prefill: + if self.mla_prefill_backend is None: + self.mla_prefill_backend = MLAPrefillBackendEnum.TRTLLM_RAGGED + logger.warning_once( + "use_trtllm_ragged_deepseek_prefill is deprecated and " + "will be removed in v0.22. Use " + "--attention-config.mla_prefill_backend=TRTLLM_RAGGED " + "instead." + ) + + if self.disable_flashinfer_prefill: + if self.mla_prefill_backend is None: + self.mla_prefill_backend = MLAPrefillBackendEnum.FLASH_ATTN + logger.warning_once( + "disable_flashinfer_prefill is deprecated and will be removed " + "in v0.22. Use --attention-config.mla_prefill_backend=" + "FLASH_ATTN instead." + ) diff --git a/aphrodite/config/model.py b/aphrodite/config/model.py index db056ba48f..61b1e558af 100644 --- a/aphrodite/config/model.py +++ b/aphrodite/config/model.py @@ -515,12 +515,12 @@ def __post_init__( if dict_overrides: self._apply_dict_overrides(hf_config, dict_overrides) self.hf_text_config = get_hf_text_config(self.hf_config) + self.model_arch_config = self.get_model_arch_config() self.attention_chunk_size = getattr(self.hf_text_config, "attention_chunk_size", None) self.encoder_config = self._get_encoder_config() self.hf_image_processor_config = get_hf_image_processor_config( self.model, hf_token=self.hf_token, revision=self.revision ) - self.model_arch_config = self.get_model_arch_config() architectures = self.architectures registry = self.registry diff --git a/aphrodite/config/parallel.py b/aphrodite/config/parallel.py index 5d07853d8e..deb193749c 100644 --- a/aphrodite/config/parallel.py +++ b/aphrodite/config/parallel.py @@ -636,6 +636,26 @@ def has_unfinished_dp(dp_group: ProcessGroup, has_unfinished: bool) -> bool: aggregated_has_unfinished = bool(tensor.item()) return aggregated_has_unfinished + @staticmethod + def sync_dp_state(dp_group: ProcessGroup, has_unfinished: bool, pending_pause: bool) -> tuple[bool, bool]: + """Combined all-reduce for DP state synchronization. + Uses a single SUM all-reduce on a 2-element tensor: + [0] = 1 if this rank has unfinished work, else 0. + SUM > 0 ≡ logical OR across ranks → any rank has work. + [1] = 1 if this rank has a pending pause request, else 0. + SUM == dp_size ≡ all ranks reached pause consensus. + has_unfinished_global is true if any rank has unfinished work, + or if some ranks are waiting for a pause consensus. + Returns: + (has_unfinished_global, pause_consensus) + """ + tensor = torch.tensor([int(has_unfinished), int(pending_pause)], dtype=torch.int32, device="cpu") + torch.distributed.all_reduce(tensor, op=ReduceOp.SUM, group=dp_group) + dp_size = dp_group.size() + pause_count = tensor[1].item() + has_unfinished_global = tensor[0].item() > 0 or pause_count % dp_size != 0 + return has_unfinished_global, pause_count == dp_size + @staticmethod def sync_kv_cache_memory_size(dp_group: ProcessGroup, kv_cache_memory: int) -> int: if kv_cache_memory == -1: @@ -686,6 +706,14 @@ def compute_hash(self): "worker_extension_cls", "_api_process_count", "_api_process_rank", + # NUMA binding is per-rank host-side memory locality; it does + # not affect collective-communication semantics. When numa_bind + # is enabled with auto-detection, each DP rank stores its own + # NUMA node in numa_bind_nodes (see aphrodite/utils/numa_utils.py + # `_get_numa_node`), which would otherwise diverge the DP hash. + "numa_bind", + "numa_bind_nodes", + "numa_bind_cpus", } from aphrodite.config.utils import get_hash_factors, hash_factors diff --git a/aphrodite/config/speculative.py b/aphrodite/config/speculative.py index 3f63c0e14d..9a4d513cba 100644 --- a/aphrodite/config/speculative.py +++ b/aphrodite/config/speculative.py @@ -5,7 +5,7 @@ import copy from typing import TYPE_CHECKING, Any, Literal, get_args -from pydantic import Field, SkipValidation, model_validator +from pydantic import Field, SkipValidation, field_validator, model_validator from typing_extensions import Self from aphrodite.config.kernel import MoEBackend @@ -17,6 +17,7 @@ from aphrodite.transformers_utils.config import get_hf_text_config from aphrodite.utils.hashing import safe_hash from aphrodite.utils.import_utils import LazyLoader, has_arctic_inference +from aphrodite.v1.attention.backends.registry import AttentionBackendEnum if TYPE_CHECKING: from transformers import PretrainedConfig @@ -32,6 +33,7 @@ MTPModelTypes = Literal[ "deepseek_mtp", "mimo_mtp", + "mimo_v2_mtp", "glm4_moe_mtp", "glm4_moe_lite_mtp", "glm_ocr_mtp", @@ -101,6 +103,10 @@ class SpeculativeConfig: inherits the target model's `--moe-backend` setting. Useful when the drafter and generator require different MoE kernels (e.g. quantized generator with unquantized drafter).""" + attention_backend: AttentionBackendEnum | None = None + """Attention backend to use for the draft model. When `None`, the backend is + automatically selected. Useful when the drafter requires a different attention + backend (e.g. DFlash needs a non-causal-capable backend like FLASH_ATTN).""" max_model_len: int | None = Field(default=None, ge=1) """The maximum model length of the draft model. Used when testing the ability to skip speculation for some sequences.""" @@ -311,6 +317,48 @@ def hf_config_override(hf_config: PretrainedConfig) -> PretrainedConfig: } ) + if (arch := hf_config.architectures[0]) in ( + "MiMoV2ForCausalLM", + "MiMoV2OmniForCausalLM", + ): + from aphrodite.model_executor.models.mimo_v2_mtp import ( + _MIMO_V2_PRO_NUM_MTP_LAYERS, + ) + + mtp_arch_maps = { + "MiMoV2ForCausalLM": "MiMoV2MTPModel", + "MiMoV2OmniForCausalLM": "MiMoV2OmniMTPModel", + } + + hf_config.model_type = "mimo_v2_mtp" + # Aphrodite currently supports only the first MiMo-V2 MTP layer. + n_predict = _MIMO_V2_PRO_NUM_MTP_LAYERS + hf_config.update( + { + "num_hidden_layers": 0, + "n_predict": n_predict, + "num_nextn_predict_layers": n_predict, + "architectures": [mtp_arch_maps[arch]], + } + ) + + if hf_config.architectures[0] == "MiMoV2FlashForCausalLM": + from aphrodite.model_executor.models.mimo_v2_mtp import ( + _MIMO_V2_FLASH_NUM_MTP_LAYERS, + ) + + hf_config.model_type = "mimo_v2_mtp" + # Aphrodite currently supports only the first MiMo-V2 MTP layer. + n_predict = _MIMO_V2_FLASH_NUM_MTP_LAYERS + hf_config.update( + { + "num_hidden_layers": 0, + "n_predict": n_predict, + "num_nextn_predict_layers": n_predict, + "architectures": ["MiMoV2MTPModel"], + } + ) + if hf_config.architectures[0] == "Glm4MoeForCausalLM": hf_config.model_type = "glm4_moe_mtp" n_predict = getattr(hf_config, "num_nextn_predict_layers", None) @@ -775,6 +823,15 @@ def create_draft_parallel_config( return draft_parallel_config + @field_validator("attention_backend", mode="before") + @classmethod + def _parse_attention_backend(cls, value: Any) -> Any: + if isinstance(value, str): + if value.lower() == "auto": + return None + return AttentionBackendEnum[value.upper()] + return value + @model_validator(mode="after") def _verify_args(self) -> Self: if self.tensor_parallel_size is not None: diff --git a/aphrodite/distributed/device_communicators/all2all.py b/aphrodite/distributed/device_communicators/all2all.py index ba4cf8740d..ec7bf9daee 100644 --- a/aphrodite/distributed/device_communicators/all2all.py +++ b/aphrodite/distributed/device_communicators/all2all.py @@ -10,7 +10,6 @@ from aphrodite.distributed import get_dp_group, get_ep_group from aphrodite.forward_context import get_forward_context from aphrodite.logger import init_logger -from aphrodite.platforms import current_platform from aphrodite.utils.flashinfer import ( has_flashinfer_nvlink_one_sided, has_flashinfer_nvlink_two_sided, @@ -218,11 +217,8 @@ def _make_all2all_kwargs(self) -> dict[Any, Any]: num_rdma_bytes=num_rdma_bytes, low_latency_mode=False, num_qps_per_rank=num_qps_per_rank, + explicitly_destroy=True, ) - if not current_platform.is_rocm(): - kwargs.update( - explicitly_destroy=True, - ) return kwargs def get_handle(self, kwargs): @@ -293,13 +289,10 @@ def _make_all2all_kwargs( num_rdma_bytes=num_rdma_bytes, low_latency_mode=True, num_qps_per_rank=num_qps_per_rank, + allow_nvlink_for_low_latency_mode=True, + allow_mnnvl=envs.APHRODITE_DEEPEP_LOW_LATENCY_USE_MNNVL, + explicitly_destroy=True, ) - if not current_platform.is_rocm(): - kwargs.update( - allow_nvlink_for_low_latency_mode=True, - allow_mnnvl=envs.APHRODITE_DEEPEP_LOW_LATENCY_USE_MNNVL, - explicitly_destroy=True, - ) return kwargs def get_handle(self, kwargs): @@ -552,6 +545,8 @@ def initialize( top_k: int, num_experts: int, hidden_size: int, + dispatch_dtype_bytes_per_elem: int = 0, + dispatch_scale_bytes_per_token: int = 0, ): """Initialize the MoeAlltoAll workspace.""" if self.initialized: @@ -582,9 +577,13 @@ def initialize( ep_config = MnnvlConfig( comm_backend=CustomCommunicator(self.cpu_group), ) + if dispatch_dtype_bytes_per_elem == 0: + hidden_bytes = hidden_size // 2 + else: + hidden_bytes = hidden_size * dispatch_dtype_bytes_per_elem total_dispatch_payload_size_per_token = ( - hidden_size // 2 # nvfp4 hidden states - + hidden_size // 16 # fp8 scaling factors + hidden_bytes + + dispatch_scale_bytes_per_token + top_k * 4 # int32 topks ids + top_k * 4 # float32 topk weights ) diff --git a/aphrodite/distributed/eplb/eplb_communicator.py b/aphrodite/distributed/eplb/eplb_communicator.py index 982908d724..372cce10ce 100644 --- a/aphrodite/distributed/eplb/eplb_communicator.py +++ b/aphrodite/distributed/eplb/eplb_communicator.py @@ -11,6 +11,7 @@ from collections.abc import Sequence from datetime import timedelta +import numpy as np import torch from torch.distributed import ( P2POp, @@ -47,15 +48,25 @@ class EplbCommunicator(ABC): """Abstract EPLB communicator for expert weight transfers.""" @abstractmethod - def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None: + def add_send( + self, + tensors: list[torch.Tensor], + dst_rank: int, + expert_id: int, + ) -> None: pass @abstractmethod - def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None: + def add_recv( + self, + tensors: list[torch.Tensor], + src_rank: int, + expert_id: int, + ) -> None: pass @abstractmethod - def execute(self) -> None: + def execute(self, old_indices: np.ndarray | None = None) -> None: pass @property @@ -85,27 +96,39 @@ def __init__( self._p2p_ops: list[P2POp] = [] self._log_initialized() - def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None: - self._p2p_ops.append( - P2POp( - torch.distributed.isend, - tensor, - dst_rank, - self._ep_group, + def add_send( + self, + tensors: list[torch.Tensor], + dst_rank: int, + expert_id: int, # unused by this backend + ) -> None: + for tensor in tensors: + self._p2p_ops.append( + P2POp( + torch.distributed.isend, + tensor, + dst_rank, + self._ep_group, + ) ) - ) - def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None: - self._p2p_ops.append( - P2POp( - torch.distributed.irecv, - tensor, - src_rank, - self._ep_group, + def add_recv( + self, + tensors: list[torch.Tensor], + src_rank: int, + expert_id: int, # unused by this backend + ) -> None: + for tensor in tensors: + self._p2p_ops.append( + P2POp( + torch.distributed.irecv, + tensor, + src_rank, + self._ep_group, + ) ) - ) - def execute(self) -> None: + def execute(self, old_indices: np.ndarray | None = None) -> None: if not self._p2p_ops: return try: @@ -130,13 +153,25 @@ def __init__( self._ops: list[tuple[str, torch.Tensor, int]] = [] self._log_initialized() - def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None: - self._ops.append(("send", tensor, dst_rank)) + def add_send( + self, + tensors: list[torch.Tensor], + dst_rank: int, + expert_id: int, # unused by this backend + ) -> None: + for tensor in tensors: + self._ops.append(("send", tensor, dst_rank)) - def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None: - self._ops.append(("recv", tensor, src_rank)) + def add_recv( + self, + tensors: list[torch.Tensor], + src_rank: int, + expert_id: int, # unused by this backend + ) -> None: + for tensor in tensors: + self._ops.append(("recv", tensor, src_rank)) - def execute(self) -> None: + def execute(self, old_indices: np.ndarray | None = None) -> None: if not self._ops: return @@ -207,30 +242,29 @@ def __init__( self._cuda_stream = cuda_stream self._world_size = cpu_group.size() self._rank = cpu_group.rank() - self._send_tensors: dict[torch.dtype, list[list[torch.Tensor]]] = {} - self._recv_tensors: dict[torch.dtype, list[list[torch.Tensor]]] = {} - self._dtypes: list[torch.dtype] = [] + # expert_id -> weight tensors to pack into the send buffer. + self._expert_send_map: dict[int, list[torch.Tensor]] = {} + # src_rank -> expert_id -> weight tensors to unpack after transfer. + self._recv_map: dict[int, dict[int, list[torch.Tensor]]] = {} + self._num_local_experts: int = expert_weights[0].shape[0] self._device = expert_weights[0].device for tensor in expert_weights: assert tensor.device == self._device, ( "All local EPLB tensors are expected to be on the same device: " f"expected={self._device}, got={tensor.device}" ) - if tensor.dtype not in self._dtypes: - self._dtypes.append(tensor.dtype) config = nixl_agent_config(capture_telemetry=False) if nixl_agent_config is not None else None self._nixl_wrapper = NixlWrapper(self._make_agent_name(), config) self._nixl_memory_type = "VRAM" self._registered_desc: object | None = None self._remote_agents: dict[int, str] = {} - self._remote_send_meta: dict[int, tuple[int, int, int]] = {} + self._remote_send_meta: dict[int, tuple[int, int]] = {} self._send_buffer: torch.Tensor = torch.empty(0) self._recv_buffer: torch.Tensor = torch.empty(0) - self._peer_partition_bytes: int = 0 - self._dtype_max_bytes: dict[torch.dtype, int] = {} + self._expert_bytes: int = 0 + self._cuda_device_id = int(self._device.index or 0) - self._xfer_cache: dict[tuple[int, int, int], tuple[int, int, int]] = {} self._init_step("buffers", self._init_registered_buffers, expert_weights) self._init_step("agents", self._init_remote_agents) self._init_step("send meta", self._exchange_remote_send_meta) @@ -254,28 +288,31 @@ def _make_agent_name(self) -> str: uid = uuid.uuid4().hex[:8] return f"eplb-{self._rank}{pp_suffix}-{uid}" - def _get_peer_buckets( + def add_send( self, - bucket_map: dict[torch.dtype, list[list[torch.Tensor]]], - dtype: torch.dtype, - ) -> list[list[torch.Tensor]]: - peer_buckets = bucket_map.get(dtype) - if peer_buckets is None: - peer_buckets = [[] for _ in range(self._world_size)] - bucket_map[dtype] = peer_buckets - return peer_buckets - - def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None: + tensors: list[torch.Tensor], + dst_rank: int, + expert_id: int, + ) -> None: assert dst_rank != self._rank, ( f"EPLB communicator should not enqueue same-rank sends: rank={self._rank}, dst_rank={dst_rank}" ) - self._get_peer_buckets(self._send_tensors, tensor.dtype)[dst_rank].append(tensor) + # An expert sent to multiple peers is packed only once; skip duplicates. + if expert_id not in self._expert_send_map: + self._expert_send_map[expert_id] = tensors - def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None: + def add_recv( + self, + tensors: list[torch.Tensor], + src_rank: int, + expert_id: int, + ) -> None: assert src_rank != self._rank, ( f"EPLB communicator should not enqueue same-rank recvs: rank={self._rank}, src_rank={src_rank}" ) - self._get_peer_buckets(self._recv_tensors, tensor.dtype)[src_rank].append(tensor) + recv_experts = self._recv_map.setdefault(src_rank, {}) + if expert_id not in recv_experts: + recv_experts[expert_id] = tensors def _init_remote_agents(self) -> None: local_metadata = self._nixl_wrapper.get_agent_metadata() @@ -289,25 +326,15 @@ def _init_remote_agents(self) -> None: self._remote_agents[peer] = self._nixl_wrapper.add_remote_agent(peer_metadata) def _init_registered_buffers(self, expert_weights: Sequence[torch.Tensor]) -> None: - total_max_bytes = 0 - for dtype in self._dtypes: - max_numel = max(sum(t.numel() for t in expert_weights if t.dtype == dtype), 1) - max_bytes = max_numel * dtype.itemsize - self._dtype_max_bytes[dtype] = max_bytes - total_max_bytes += max_bytes - - self._peer_partition_bytes = total_max_bytes - - # The send buffer needs world_size partitions because remote peers - # READ from fixed offsets (rank * partition_bytes). - # This allocates world_size * partition_bytes - # which can cause OOM on large models. - # TODO(ilmarkov): shrink to const * partition_bytes and execute - # communication in multiple steps dealing with the worst case. - send_total_bytes = self._peer_partition_bytes * self._world_size - - self._send_buffer = torch.empty(send_total_bytes, device=self._device, dtype=torch.uint8) - self._recv_buffer = torch.empty(self._peer_partition_bytes, device=self._device, dtype=torch.uint8) + total_bytes = max(sum(t.nbytes for t in expert_weights), 1) + assert total_bytes % self._num_local_experts == 0, ( + f"Number of bytes in moe layer {total_bytes} is not divisible " + f"by number of local experts {self._num_local_experts}" + ) + self._expert_bytes = total_bytes // self._num_local_experts + + self._send_buffer = torch.empty(total_bytes, device=self._device, dtype=torch.uint8) + self._recv_buffer = torch.empty(total_bytes, device=self._device, dtype=torch.uint8) descs = self._nixl_wrapper.get_reg_descs([self._send_buffer, self._recv_buffer]) self._nixl_wrapper.register_memory(descs) @@ -316,12 +343,11 @@ def _init_registered_buffers(self, expert_weights: Sequence[torch.Tensor]) -> No def _exchange_remote_send_meta(self) -> None: """Exchange send-buffer metadata so each rank can build dynamic descriptors at execute time.""" - local_meta: tuple[int, int, int] = ( + local_meta: tuple[int, int] = ( self._send_buffer.data_ptr(), - self._peer_partition_bytes, self._cuda_device_id, ) - gathered_meta: list[tuple[int, int, int] | None] = [None] * self._world_size + gathered_meta: list[tuple[int, int] | None] = [None] * self._world_size torch.distributed.all_gather_object(gathered_meta, local_meta, group=self._cpu_group) for peer in self._remote_agents: @@ -331,31 +357,24 @@ def _exchange_remote_send_meta(self) -> None: @staticmethod def _pack_send_buffer( - peer_tensors: list[torch.Tensor], + in_tensors: list[torch.Tensor], send_buffer: torch.Tensor, byte_offset: int, - ) -> int: - """ - Returns the byte offset after the last written byte. - """ - for tensor in peer_tensors: + ) -> None: + for tensor in in_tensors: raw = tensor.reshape(-1).view(torch.uint8) if raw.numel() == 0: continue send_buffer[byte_offset : byte_offset + raw.numel()].copy_(raw, non_blocking=True) byte_offset += raw.numel() - return byte_offset @staticmethod def _unpack_recv_buffer( recv_buffer: torch.Tensor, - peer_tensors: list[torch.Tensor], + out_tensors: list[torch.Tensor], byte_offset: int, - ) -> int: - """ - Returns the byte offset after the last read byte. - """ - for tensor in peer_tensors: + ) -> None: + for tensor in out_tensors: num_bytes = tensor.numel() * tensor.element_size() if num_bytes == 0: continue @@ -364,19 +383,6 @@ def _unpack_recv_buffer( non_blocking=True, ) byte_offset += num_bytes - return byte_offset - - def _release_all_cached_handles(self) -> None: - """Best-effort release of every cached dlist and xfer handle.""" - for local_dlist, remote_dlist, xfer in self._xfer_cache.values(): - for release_fn, handle in ( - (self._nixl_wrapper.release_xfer_handle, xfer), - (self._nixl_wrapper.release_dlist_handle, local_dlist), - (self._nixl_wrapper.release_dlist_handle, remote_dlist), - ): - with contextlib.suppress(Exception): - release_fn(handle) - self._xfer_cache.clear() def _wait_for_all_transfers(self, handles: list[int]) -> None: pending = set(handles) @@ -394,78 +400,59 @@ def _wait_for_all_transfers(self, handles: list[int]) -> None: if pending: time.sleep(0.0005) - def _get_or_create_xfer(self, src: int, total_bytes: int, recv_offset: int) -> int: - """Return a cached xfer handle or create and cache a new one.""" - key = (src, total_bytes, recv_offset) - cached = self._xfer_cache.get(key) - if cached is not None: - return cached[2] - - recv_base = self._recv_buffer.data_ptr() - local_desc = self._nixl_wrapper.get_xfer_descs( - [ - ( - recv_base + recv_offset, - total_bytes, - self._cuda_device_id, - ) - ], - self._nixl_memory_type, - ) + def _create_peer_xfer( + self, + src: int, + local_descs: list[tuple[int, int, int]], + remote_descs: list[tuple[int, int, int]], + ) -> tuple[int, int, int]: + """Create a batched xfer for multiple descriptors from one peer. + + Each element in *local_descs* / *remote_descs* is an + ``(address, size, device_id)`` tuple. + + Returns ``(local_dlist, remote_dlist, xfer_handle)``. + """ + local_desc = self._nixl_wrapper.get_xfer_descs(local_descs, self._nixl_memory_type) local_handle = self._nixl_wrapper.prep_xfer_dlist( "NIXL_INIT_AGENT", local_desc, ) - remote_base, remote_part_bytes, remote_dev = self._remote_send_meta[src] - agent_name = self._remote_agents[src] - remote_desc = self._nixl_wrapper.get_xfer_descs( - [ - ( - remote_base + self._rank * remote_part_bytes, - total_bytes, - remote_dev, - ) - ], - self._nixl_memory_type, - ) + remote_desc = self._nixl_wrapper.get_xfer_descs(remote_descs, self._nixl_memory_type) remote_handle = self._nixl_wrapper.prep_xfer_dlist( - agent_name, + self._remote_agents[src], remote_desc, ) + indices = list(range(len(local_descs))) xfer_handle = self._nixl_wrapper.make_prepped_xfer( "READ", local_handle, - [0], + indices, remote_handle, - [0], + indices, ) - self._xfer_cache[key] = (local_handle, remote_handle, xfer_handle) - return xfer_handle + return (local_handle, remote_handle, xfer_handle) + + def execute(self, old_indices: np.ndarray | None = None) -> None: + assert old_indices is not None, "NixlEplbCommunicator.execute requires old_indices" - def execute(self) -> None: - xfer_handles: list[int] = [] + xfer_entries: list[tuple[int, int, int]] = [] try: - # Phase 1: pack send buffers. + n = self._num_local_experts + rank_experts = old_indices[: self._world_size * n].reshape(self._world_size, n) + # Build expert_id -> send slot mapping per rank. + expert_to_send_slot: list[dict[int, int]] = [ + {int(eid): i for i, eid in enumerate(row) if eid != -1} for row in rank_experts + ] + + # Phase 1: pack each expert at its slot offset in the send buffer. with torch.cuda.stream(self._cuda_stream): - for dst in range(self._world_size): - byte_offset = dst * self._peer_partition_bytes - for dtype in self._dtypes: - peer_tensors = self._send_tensors.get(dtype, [[] for _ in range(self._world_size)])[dst] - actual_bytes = sum(t.numel() * t.element_size() for t in peer_tensors) - if actual_bytes > self._dtype_max_bytes[dtype]: - raise RuntimeError( - "NIXL EPLB send overflow for dtype " - f"{dtype}: peer={dst}, " - f"required={actual_bytes}, " - f"capacity={self._dtype_max_bytes[dtype]}" - ) - byte_offset = self._pack_send_buffer( - peer_tensors, - self._send_buffer, - byte_offset, - ) + for expert_id, tensors in self._expert_send_map.items(): + slot = expert_to_send_slot[self._rank][expert_id] + byte_offset = slot * self._expert_bytes + self._pack_send_buffer(tensors, self._send_buffer, byte_offset) # Ensure all packed data is visible in device memory before pulls. if self._cuda_stream is not None: @@ -480,50 +467,61 @@ def execute(self) -> None: timeout=timedelta(minutes=5), ) - # Phase 2: look up or create descriptors and issue all READs. - # Data from all peers is packed sequentially into the single - # partition-sized recv buffer at running offsets. - recv_offsets: dict[int, int] = {} + # Phase 2: issue one batched READ per peer. + recv_offsets: dict[tuple[int, int], int] = {} recv_offset = 0 + recv_base = self._recv_buffer.data_ptr() for src in range(self._world_size): if src == self._rank: continue - actual_total_bytes = 0 - for dtype in self._dtypes: - peer_tensors = self._recv_tensors.get(dtype, [[] for _ in range(self._world_size)])[src] - actual_total_bytes += sum(t.numel() * t.element_size() for t in peer_tensors) - if actual_total_bytes == 0: + recv_experts = self._recv_map.get(src) + if not recv_experts: continue + expert_ids = list(recv_experts.keys()) + remote_base, remote_dev = self._remote_send_meta[src] + local_descs: list[tuple[int, int, int]] = [] + remote_descs: list[tuple[int, int, int]] = [] + for expert_id in expert_ids: + slot = expert_to_send_slot[src][expert_id] + remote_off = slot * self._expert_bytes + recv_offsets[(src, expert_id)] = recv_offset + local_descs.append( + ( + recv_base + recv_offset, + self._expert_bytes, + self._cuda_device_id, + ) + ) + remote_descs.append((remote_base + remote_off, self._expert_bytes, remote_dev)) + recv_offset += self._expert_bytes + assert recv_offset <= self._recv_buffer.nbytes + local_h, remote_h, xfer_h = self._create_peer_xfer(src, local_descs, remote_descs) + self._nixl_wrapper.transfer(xfer_h) + xfer_entries.append((local_h, remote_h, xfer_h)) - recv_offsets[src] = recv_offset - xfer_handle = self._get_or_create_xfer(src, actual_total_bytes, recv_offset) - self._nixl_wrapper.transfer(xfer_handle) - xfer_handles.append(xfer_handle) - recv_offset += actual_total_bytes - - # Phase 3: single wait for all in-flight transfers, then unpack. - self._wait_for_all_transfers(xfer_handles) + # Phase 3: wait for all in-flight transfers, then unpack. + self._wait_for_all_transfers([x[2] for x in xfer_entries]) with torch.cuda.stream(self._cuda_stream): - for src, offset in recv_offsets.items(): - byte_offset = offset - for dtype in self._dtypes: - peer_tensors = self._recv_tensors.get(dtype, [[] for _ in range(self._world_size)])[src] - byte_offset = self._unpack_recv_buffer( - self._recv_buffer, - peer_tensors, - byte_offset, - ) - except Exception: - self._release_all_cached_handles() - raise + for (src, expert_id), offset in recv_offsets.items(): + self._unpack_recv_buffer( + self._recv_buffer, + self._recv_map[src][expert_id], + offset, + ) finally: - self._send_tensors.clear() - self._recv_tensors.clear() + for local_h, remote_h, xfer_h in xfer_entries: + with contextlib.suppress(Exception): + self._nixl_wrapper.release_xfer_handle(xfer_h) + with contextlib.suppress(Exception): + self._nixl_wrapper.release_dlist_handle(local_h) + with contextlib.suppress(Exception): + self._nixl_wrapper.release_dlist_handle(remote_h) + self._expert_send_map.clear() + self._recv_map.clear() def __del__(self) -> None: try: - self._release_all_cached_handles() if self._registered_desc is not None: self._nixl_wrapper.deregister_memory(self._registered_desc) self._registered_desc = None @@ -552,15 +550,27 @@ def _ensure_group_started(self) -> None: self._pynccl_comm.group_start() self._group_started = True - def add_send(self, tensor: torch.Tensor, dst_rank: int) -> None: + def add_send( + self, + tensors: list[torch.Tensor], + dst_rank: int, + expert_id: int, # unused by this backend + ) -> None: self._ensure_group_started() - self._pynccl_comm.send(tensor, dst_rank, stream=self._cuda_stream) + for tensor in tensors: + self._pynccl_comm.send(tensor, dst_rank, stream=self._cuda_stream) - def add_recv(self, tensor: torch.Tensor, src_rank: int) -> None: + def add_recv( + self, + tensors: list[torch.Tensor], + src_rank: int, + expert_id: int, # unused by this backend + ) -> None: self._ensure_group_started() - self._pynccl_comm.recv(tensor, src_rank, stream=self._cuda_stream) + for tensor in tensors: + self._pynccl_comm.recv(tensor, src_rank, stream=self._cuda_stream) - def execute(self) -> None: + def execute(self, old_indices: np.ndarray | None = None) -> None: if self._group_started: self._pynccl_comm.group_end() self._group_started = False diff --git a/aphrodite/distributed/eplb/rebalance_execute.py b/aphrodite/distributed/eplb/rebalance_execute.py index 14c58e460f..2313ac2988 100644 --- a/aphrodite/distributed/eplb/rebalance_execute.py +++ b/aphrodite/distributed/eplb/rebalance_execute.py @@ -280,9 +280,9 @@ def move_to_buffer( recver_pos = remainder_start + sender_pos if recver_pos < len(ranks_to_recv): recv_ranks.append(ranks_to_recv[recver_pos]) + expert_tensors = [w[src] for w in expert_weights] for dst in recv_ranks: - for w in expert_weights: - communicator.add_send(w[src], dst) + communicator.add_send(expert_tensors, dst, expert_id=int(expert)) # 3. Post recvs if recv_count > 0: @@ -311,11 +311,14 @@ def move_to_buffer( src = ranks_to_send[recver_pos // num_dst_per_sender] else: src = ranks_to_send[recver_pos - remainder_start] - for b in expert_weights_buffers: - communicator.add_recv(b[dst], src) + communicator.add_recv( + [b[dst] for b in expert_weights_buffers], + src, + expert_id=int(expert), + ) # 4. Execute the P2P operations. The real communication happens here. - communicator.execute() + communicator.execute(old_indices=old_indices) # wait for the communication to finish return TransferMetadata( is_unchanged=is_unchanged, diff --git a/aphrodite/distributed/kv_transfer/kv_connector/v1/multi_connector.py b/aphrodite/distributed/kv_transfer/kv_connector/v1/multi_connector.py index 258efb72fa..00da7d07ff 100644 --- a/aphrodite/distributed/kv_transfer/kv_connector/v1/multi_connector.py +++ b/aphrodite/distributed/kv_transfer/kv_connector/v1/multi_connector.py @@ -1,9 +1,9 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import copy -from collections.abc import Iterable +from collections.abc import Callable, Iterable from dataclasses import dataclass -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, cast import torch @@ -18,6 +18,8 @@ KVConnectorMetadata, KVConnectorRole, KVConnectorWorkerMetadata, + SupportsHMA, + supports_hma, ) from aphrodite.distributed.kv_transfer.kv_connector.v1.metrics import ( KVConnectorPromMetrics, @@ -121,7 +123,7 @@ def observe(self, transfer_stats_data: dict[str, Any], engine_idx: int = 0): self._prom_metrics[connector_id].observe(stats_data["data"], engine_idx) -class MultiConnector(KVConnectorBase_V1): +class MultiConnector(KVConnectorBase_V1, SupportsHMA): """ A wrapper for using multiple KVConnectors at the same time. @@ -160,6 +162,11 @@ def __init__( self._connectors.append(connector_cls(temp_config, role, kv_cache_config)) self._ktc_kv_transfer_config.append(temp_config.kv_transfer_config) + self._all_support_hma = all(supports_hma(c) for c in self._connectors) + assert aphrodite_config.scheduler_config.disable_hybrid_kv_cache_manager or self._all_support_hma, ( + "HMA should not be enabled unless all sub-connectors support it" + ) + # A mapping from request id to the index of the connector chosen to # load the request from (if any). self._requests_to_connector: dict[str, int] = {} @@ -406,15 +413,15 @@ def set_xfer_handshake_metadata(self, metadata: dict[int, KVConnectorHandshakeMe for c in self._connectors: c.set_xfer_handshake_metadata(metadata) - def request_finished( + def _aggregate_request_finished( self, request: "Request", - blocks: list[int], + per_connector_fn: Callable[[KVConnectorBase_V1], tuple[bool, dict[str, Any] | None]], ) -> tuple[bool, dict[str, Any] | None]: async_saves = 0 kv_txfer_params = None for c in self._connectors: - async_save, txfer_params = c.request_finished(request, blocks) + async_save, txfer_params = per_connector_fn(c) if async_save: async_saves += 1 if txfer_params is not None: @@ -426,11 +433,34 @@ def request_finished( if async_saves > 1: self._extra_async_saves[request.request_id] = async_saves - 1 - # Clean up other state for this request. self._requests_to_connector.pop(request.request_id, None) return async_saves > 0, kv_txfer_params + def request_finished( + self, + request: "Request", + blocks: list[int], + ) -> tuple[bool, dict[str, Any] | None]: + return self._aggregate_request_finished( + request, + lambda c: c.request_finished(request, blocks), + ) + + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + if not self._all_support_hma: + assert len(block_ids) == 1, "HMA with multiple kv_cache_groups requires all sub-connectors to support HMA" + return self.request_finished(request, block_ids[0]) + + return self._aggregate_request_finished( + request, + lambda c: cast(SupportsHMA, c).request_finished_all_groups(request, block_ids), + ) + def take_events(self) -> Iterable["KVCacheEvent"]: for c in self._connectors: yield from c.take_events() diff --git a/aphrodite/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py b/aphrodite/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py index a890fa8a04..3e6057bf89 100644 --- a/aphrodite/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py +++ b/aphrodite/distributed/kv_transfer/kv_connector/v1/nixl/scheduler.py @@ -1,5 +1,6 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project +# SPDX-FileCopyrightText: Copyright contributors to the aphrodite project """Scheduler-side logic for the NIXL connector.""" import threading @@ -109,6 +110,25 @@ def __init__( cdiv(n_tokens, block_size) + 1 if n_tokens else 0 for n_tokens, block_size in sw_sizes_tokens ] + # Threshold to decide whether to compute kv cache locally + # or pull from a remote node: minimum number of remote + # tokens to amortize the xfer latencies + self.kv_recompute_threshold: int = int( + aphrodite_config.kv_transfer_config.get_from_extra_config("kv_recompute_threshold", 64) + ) + + # Bi-directional KV transfer feature supports KV block + # transfers from D node to P node + self.is_bidirectional_kv_xfer_enabled = aphrodite_config.kv_transfer_config.get_from_extra_config( + "bidirectional_kv_xfer", False + ) + + if self.is_bidirectional_kv_xfer_enabled and self.kv_recompute_threshold > 0: + logger.info( + "Bidirectional KV transfer is enabled and the kv recompute threshold is set to %d tokens", + self.kv_recompute_threshold, + ) + def shutdown(self): self._stop_event.set() if self._nixl_handshake_listener_t is not None: @@ -276,6 +296,39 @@ def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: in if params is not None and params.get("do_remote_decode") and self._has_mamba: self._truncate_mamba_request_for_prefill(request) + if ( + params is not None + and params.get("do_remote_decode") + and params.get("remote_block_ids") + and all( + p in params + for p in ( + "remote_engine_id", + "remote_request_id", + "remote_host", + "remote_port", + ) + ) + ): + # Decode node has kv blocks for part of prefill request, so, provide them + # as an external token count to scheduler. + # The tokens will be loaded if not already present + # in the prefill node local cache + remote_num_tokens = params.get("remote_num_tokens") or 0 + count = min(remote_num_tokens, request.num_prompt_tokens) - num_computed_tokens + if count > 0: + # Check kv_recompute_threshold: skip pull if + # remote tokens are below the threshold. + if self.kv_recompute_threshold > 0 and count < self.kv_recompute_threshold: + logger.debug( + "Skipping remote pull for %s: %d remote tokens < threshold %d", + request.request_id, + count, + self.kv_recompute_threshold, + ) + return 0, False + return count, True + # No remote prefill for this request. return 0, False @@ -290,13 +343,19 @@ def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", if not params: return - if params.get("do_remote_decode"): + if params.get("do_remote_decode") or ( + params.get("do_remote_prefill") and self.is_bidirectional_kv_xfer_enabled + ): self._reqs_in_batch.add(request.request_id) if self.use_host_buffer and params.get("do_remote_decode"): # NOTE: when accelerator is not directly supported by Nixl, # prefilled blocks need to be saved to host memory before transfer. self._reqs_need_save[request.request_id] = request - elif params.get("do_remote_prefill"): + elif params.get("do_remote_prefill") or ( + params.get("do_remote_decode") + and self.is_bidirectional_kv_xfer_enabled + and not params.get("_remote_blocks_processed") + ): if params.get("remote_block_ids"): if all( p in params @@ -308,8 +367,8 @@ def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", ) ): # If remote_blocks and num_external_tokens = 0, we have - # a full prefix cache hit on the D worker. We need to call - # send_notif in _read_blocks to free the memory on the P. + # a full prefix cache hit on the local node. We need to call + # send_notif in _read_blocks to free the memory on the remote node. unhashed_local_block_ids: BlockIds = ( blocks.get_unhashed_block_ids_all_groups() if num_external_tokens > 0 else () @@ -332,6 +391,7 @@ def update_state_after_alloc(self, request: "Request", blocks: "KVCacheBlocks", assert num_external_tokens == 0 # Only trigger 1 KV transfer per request. params["do_remote_prefill"] = False + params["_remote_blocks_processed"] = True def _build_save_meta( self, @@ -417,6 +477,9 @@ def request_finished( if not params: return False, None + is_p_node = bool(params.get("do_remote_decode")) + is_d_node = not is_p_node + if params.get("do_remote_prefill"): # If do_remote_prefill is still True when the request is finished, # update_state_after_alloc must not have been called (the request @@ -428,9 +491,13 @@ def request_finished( params["do_remote_prefill"] = False return False, None - if not params.get("do_remote_decode"): + if is_d_node and not self.is_bidirectional_kv_xfer_enabled: return False, None - if request.status != RequestStatus.FINISHED_LENGTH_CAPPED: + + if request.status not in ( + RequestStatus.FINISHED_LENGTH_CAPPED, + RequestStatus.FINISHED_STOPPED, + ): # Also include the case of a P/D Prefill request with immediate # block free (eg abort). Stop tracking this request. self._reqs_not_processed.add(request.request_id) @@ -441,6 +508,7 @@ def request_finished( # TODO: check whether block_ids actually ever be 0. If not we could # remove the conditional below delay_free_blocks = any(len(group) > 0 for group in block_ids) + remote_num_tokens = 0 if delay_free_blocks: # Prefill request on remote. It will be read from D upon completion @@ -456,13 +524,16 @@ def request_finished( # Here we "unpad" blocks to send the actual remote blocks to be read. block_ids = self.get_sw_clipped_blocks(block_ids) + remote_num_tokens = request.num_computed_tokens + return delay_free_blocks, dict( - do_remote_prefill=True, - do_remote_decode=False, + do_remote_prefill=is_p_node, + do_remote_decode=is_d_node, remote_block_ids=block_ids, remote_engine_id=self.engine_id, remote_request_id=request.request_id, remote_host=self.side_channel_host, remote_port=self.side_channel_port, tp_size=self.aphrodite_config.parallel_config.tensor_parallel_size, + remote_num_tokens=remote_num_tokens, ) diff --git a/aphrodite/distributed/kv_transfer/kv_connector/v1/nixl/worker.py b/aphrodite/distributed/kv_transfer/kv_connector/v1/nixl/worker.py index f2e7c5f98b..064b5a0022 100644 --- a/aphrodite/distributed/kv_transfer/kv_connector/v1/nixl/worker.py +++ b/aphrodite/distributed/kv_transfer/kv_connector/v1/nixl/worker.py @@ -1746,7 +1746,7 @@ def _read_blocks_for_req(self, req_id: str, meta: ReqMeta): if self.use_mla and tp_ratio < 0: # ..but we still need to notify the other remote ranks that we # have the blocks we need so they can update the request state. - notif_id = f"{req_id}:{self.world_size}".encode() + notif_id = f"{meta.remote.request_id}:{self.world_size}".encode() remote_agents = self._remote_agents[meta.remote.engine_id] for rank_to_notify, agent in remote_agents.items(): if rank_to_notify != remote_rank: diff --git a/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading/common.py b/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading/common.py index 3e8567a37b..601f22c44b 100644 --- a/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading/common.py +++ b/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading/common.py @@ -1,15 +1,56 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from dataclasses import dataclass +from dataclasses import dataclass, field -from aphrodite.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata +from aphrodite.distributed.kv_transfer.kv_connector.v1.base import ( + KVConnectorMetadata, + KVConnectorWorkerMetadata, +) from aphrodite.v1.kv_offload.worker.worker import TransferSpec ReqId = str +@dataclass +class TransferJob: + """A transfer job bundling request context with transfer spec. + Used for both loads and stores, keyed by scheduler-assigned job ID. + The worker reports the job ID back when the transfer finishes, + and the scheduler processes the completion. + """ + + req_id: ReqId + transfer_spec: TransferSpec + + @dataclass class OffloadingConnectorMetadata(KVConnectorMetadata): - reqs_to_load: dict[ReqId, TransferSpec] - reqs_to_store: dict[ReqId, TransferSpec] - reqs_to_flush: set[str] | None = None + # Keyed by scheduler-assigned job IDs. + load_jobs: dict[int, TransferJob] + store_jobs: dict[int, TransferJob] + jobs_to_flush: set[int] | None = None + + +@dataclass +class OffloadingWorkerMetadata(KVConnectorWorkerMetadata): + """Worker -> Scheduler metadata for completed transfer jobs. + Each worker reports {job_id: 1} for newly completed transfer jobs + (load or store). aggregate() sums counts across workers within a step. + The scheduler accumulates across steps and processes + a transfer completion only when count reaches num_workers. + """ + + completed_jobs: dict[int, int] = field(default_factory=dict) + + def mark_completed(self, job_id: int) -> None: + """Record a transfer job completion from this worker.""" + self.completed_jobs[job_id] = 1 + + def aggregate(self, other: "KVConnectorWorkerMetadata") -> "KVConnectorWorkerMetadata": + assert isinstance(other, OffloadingWorkerMetadata) + + merged = dict(self.completed_jobs) + for job_id, v in other.completed_jobs.items(): + merged[job_id] = merged.get(job_id, 0) + v + + return OffloadingWorkerMetadata(completed_jobs=merged) diff --git a/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py b/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py index bc29bc5959..f2cdfaa642 100644 --- a/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py +++ b/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading/scheduler.py @@ -1,6 +1,5 @@ # SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project -from collections import defaultdict from collections.abc import Iterable, Sequence from dataclasses import dataclass, field from itertools import islice @@ -11,48 +10,95 @@ from aphrodite.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from aphrodite.distributed.kv_transfer.kv_connector.v1.offloading.common import ( OffloadingConnectorMetadata, + OffloadingWorkerMetadata, ReqId, + TransferJob, ) from aphrodite.logger import init_logger from aphrodite.utils.math_utils import cdiv from aphrodite.v1.core.kv_cache_manager import KVCacheBlocks from aphrodite.v1.core.sched.output import SchedulerOutput -from aphrodite.v1.kv_offload.abstract import ( +from aphrodite.v1.kv_cache_interface import ( + FullAttentionSpec, + KVCacheSpec, + MambaSpec, + SlidingWindowSpec, +) +from aphrodite.v1.kv_offload.base import ( + GPULoadStoreSpec, OffloadingManager, + OffloadingSpec, OffloadKey, ReqContext, get_offload_block_hash, make_offload_key, ) -from aphrodite.v1.kv_offload.mediums import GPULoadStoreSpec -from aphrodite.v1.kv_offload.spec import OffloadingSpec -from aphrodite.v1.kv_offload.worker.worker import TransferSpec from aphrodite.v1.outputs import KVConnectorOutput from aphrodite.v1.request import Request logger = init_logger(__name__) +@dataclass(slots=True) +class TransferJobStatus: + """Tracks scheduler-side state for a single transfer job.""" + + req_id: ReqId + # Number of workers still pending. Starts at num_workers, + # decremented as each worker reports completion. Job is done at 0. + pending_count: int + # Offload keys this job covers; passed to manager.complete_*(). + keys: set[OffloadKey] + is_store: bool + # Store src block IDs whose ref_cnt protects them while the request + # runs. Only registered in _block_id_to_pending_jobs on request_finished. + non_sliding_window_block_ids: list[int] | None = None + # Store src block IDs that may be freed before the request finishes. + # Registered in _block_id_to_pending_jobs at store creation time. + sliding_window_block_ids: list[int] | None = None + + class GroupOffloadConfig(NamedTuple): group_idx: int gpu_block_size: int offloaded_block_size: int hash_block_size_factor: int + # None below means full attention + sliding_window_size_in_blocks: int | None + + +def get_sliding_window_size_in_blocks(kv_cache_spec: KVCacheSpec, offloaded_block_size: int) -> int | None: + if isinstance(kv_cache_spec, SlidingWindowSpec): + assert kv_cache_spec.sliding_window > 0 + return cdiv(kv_cache_spec.sliding_window, offloaded_block_size) + + if isinstance(kv_cache_spec, MambaSpec): + # Mamba depends on a single state + return 1 + + assert isinstance(kv_cache_spec, FullAttentionSpec) + return None class SchedulerOffloadConfig(NamedTuple): kv_group_configs: tuple[GroupOffloadConfig, ...] block_size_factor: int + num_workers: int @classmethod def from_spec(cls, spec: OffloadingSpec) -> "SchedulerOffloadConfig": return cls( + num_workers=spec.aphrodite_config.parallel_config.world_size, kv_group_configs=tuple( GroupOffloadConfig( group_idx=idx, gpu_block_size=gpu_block_size, offloaded_block_size=gpu_block_size * spec.block_size_factor, hash_block_size_factor=((gpu_block_size * spec.block_size_factor) // spec.hash_block_size), + sliding_window_size_in_blocks=get_sliding_window_size_in_blocks( + spec.kv_cache_config.kv_cache_groups[idx].kv_cache_spec, + gpu_block_size * spec.block_size_factor, + ), ) for idx, gpu_block_size in enumerate(spec.gpu_block_size) ), @@ -66,6 +112,9 @@ class RequestGroupState: block_ids: list[int] = field(default_factory=list) # index of next block (of size offloaded_block_size) to offload next_stored_block_idx: int = 0 + # number of offloaded blocks hit (including GPU prefix cache) + # when the request first started + num_hit_blocks: int = 0 @dataclass(slots=True) @@ -76,6 +125,9 @@ class RequestOffloadState: req_context: ReqContext = field(init=False) # number of hits in the GPU cache num_locally_computed_tokens: int = 0 + # In-flight job IDs. Per the connector's invariant, at any given time + # this contains either a single load job, or one or more store jobs. + transfer_jobs: set[int] = field(default_factory=set) def __post_init__(self) -> None: self.group_states = tuple(RequestGroupState() for _ in self.config.kv_group_configs) @@ -106,6 +158,10 @@ def advance_stored_idx(self, num_offloadable_tokens: int) -> None: num_blocks = num_offloadable_tokens // group_config.offloaded_block_size group_state.next_stored_block_idx = num_blocks + def update_num_hit_blocks(self, num_cached_tokens: int) -> None: + for group_config, group_state in zip(self.config.kv_group_configs, self.group_states): + group_state.num_hit_blocks = num_cached_tokens // group_config.offloaded_block_size + class OffloadingConnectorScheduler: """Implementation of Scheduler side methods""" @@ -114,28 +170,61 @@ def __init__(self, spec: OffloadingSpec): self.config = SchedulerOffloadConfig.from_spec(spec) self.manager: OffloadingManager = spec.get_manager() - attention_groups: list[int] = [] - for idx, _ in enumerate(spec.kv_cache_config.kv_cache_groups): - # currently treat all groups as full attention - attention_groups.append(idx) + full_attention_groups: list[int] = [] + sliding_window_groups: list[int] = [] + for group_config in self.config.kv_group_configs: + if group_config.sliding_window_size_in_blocks is None: + full_attention_groups.append(group_config.group_idx) + else: + sliding_window_groups.append(group_config.group_idx) + + # sort sliding window groups by window size in decreasing order + def _sliding_window_sort_key(i: int) -> int: + val = self.config.kv_group_configs[i].sliding_window_size_in_blocks + assert val is not None + return val + + sliding_window_groups.sort(key=_sliding_window_sort_key, reverse=True) - self.lookup_groups = attention_groups + # used by _lookup + self._sliding_window_groups: tuple[int, ...] = tuple(sliding_window_groups) + self._lookup_groups = tuple(full_attention_groups) + self._sliding_window_groups self._req_status: dict[ReqId, RequestOffloadState] = {} - # requests to load for the current scheduler step - self._reqs_to_load: dict[ReqId, TransferSpec] = {} + self._current_batch_load_jobs: dict[int, TransferJob] = {} + self._current_batch_jobs_to_flush: set[int] = set() # if GPU prefix caching is enabled, # track loaded blocks to avoid redundant loads self._blocks_being_loaded: set[OffloadKey] | None = ( set() if spec.aphrodite_config.cache_config.enable_prefix_caching else None ) - # request ID -> set(offload keys being stored/loaded) - self._reqs_being_stored = defaultdict[ReqId, set[OffloadKey]](set) - self._reqs_being_loaded = defaultdict[ReqId, set[OffloadKey]](set) + # Job ID counter shared by loads and stores. + self._job_counter: int = 0 + self._jobs: dict[int, TransferJobStatus] = {} + + # block_id -> pending store job_ids. Used to track jobs that needs + # flushing in case a block is re-allocated by the KV cache manager. + # Populated only for finished requests (running-request blocks are + # protected by their ref_cnt) and for sliding window blocks (which can + # be freed before a request finishes). + self._block_id_to_pending_jobs: dict[int, set[int]] = {} + + def _generate_job_id(self) -> int: + job_id = self._job_counter + self._job_counter += 1 + return job_id + + def _remove_pending_job(self, job_id: int, block_ids: list[int] | None) -> None: + for bid in block_ids or (): + pending = self._block_id_to_pending_jobs[bid] + pending.remove(job_id) + if not pending: + del self._block_id_to_pending_jobs[bid] def _maximal_prefix_lookup(self, keys: Iterable[OffloadKey], req_context: ReqContext) -> int | None: - """Find the length of the maximal prefix of offloaded blocks.""" + """Return the number of consecutive offloaded blocks from the start, + or None if the backend deferred a lookup.""" hit_count = 0 defer_lookup = False for key in keys: @@ -156,8 +245,9 @@ def _sliding_window_lookup( sliding_window_size: int, req_context: ReqContext, ) -> int | None: - """Find the maximal ending position of consecutive offloaded blocks - within a sliding window.""" + """Return the end index (in `keys`) of the last run of + `sliding_window_size` consecutive hits, scanning from the end. + Returns 0 on miss, None if the backend deferred a lookup.""" defer_lookup = False consecutive_hits = 0 for idx in range(len(keys) - 1, -1, -1): @@ -175,6 +265,137 @@ def _sliding_window_lookup( return idx + sliding_window_size if not defer_lookup else None return consecutive_hits if not defer_lookup else None + def _touch(self, req_status: RequestOffloadState): + for group_config, group_state in zip(self.config.kv_group_configs, req_status.group_states): + if group_config.sliding_window_size_in_blocks is None: + self.manager.touch(group_state.offload_keys) + else: + # we aim to keep just blocks that are necessary to hit + # the original request (+ decoded blocks) + blocks_to_skip = max( + 0, + group_state.num_hit_blocks - group_config.sliding_window_size_in_blocks, + ) + self.manager.touch(group_state.offload_keys[blocks_to_skip:]) + + def _lookup(self, req_status: RequestOffloadState) -> int | None: + """ + Find how many tokens beyond num_locally_computed_tokens can be loaded. + + Iterates full-attention groups first (prefix lookup), then sliding-window + groups (suffix lookup). Each group may tighten max_hit_size_tokens, which + can invalidate an earlier group's result, so the loop re-runs when that + happens until num_hit_tokens converges. + """ + num_computed_tokens = req_status.num_locally_computed_tokens + max_hit_size_tokens: int = req_status.req.num_tokens + if self._sliding_window_groups: + # the last prompt token has to be recomputed to get the logprobs + # for sliding window attention, we must reduce by 1 to make sure + # we still have a hit after reduction + max_hit_size_tokens -= 1 + num_hit_tokens: int = 0 + defer_lookup = False + lookup_groups = self._lookup_groups + while lookup_groups: + looked_up_sliding_window: bool = False + groups_iter = iter(lookup_groups) + lookup_groups = () + for group_idx in groups_iter: + group_config: GroupOffloadConfig = self.config.kv_group_configs[group_idx] + group_state: RequestGroupState = req_status.group_states[group_idx] + offloaded_block_size = group_config.offloaded_block_size + offload_keys = group_state.offload_keys + + assert len(offload_keys) >= req_status.req.num_tokens // offloaded_block_size + + # Constrain to block-aligned boundary for this group + max_hit_size_tokens = min(max_hit_size_tokens, len(offload_keys) * offloaded_block_size) + if max_hit_size_tokens - num_computed_tokens < offloaded_block_size: + # we can only load less than a block, better skip + return 0 + + num_blocks = min(cdiv(max_hit_size_tokens, offloaded_block_size), len(offload_keys)) + start_block_idx = num_computed_tokens // offloaded_block_size + offload_keys = offload_keys[start_block_idx:num_blocks] + sliding_window_size_in_blocks = group_config.sliding_window_size_in_blocks + + # end index (in the sliced offload_keys) up to which we + # have backend-confirmed hits + num_hit_blocks: int | None + if sliding_window_size_in_blocks is None: + num_hit_blocks = self._maximal_prefix_lookup(offload_keys, req_status.req_context) + else: + num_hit_blocks = self._sliding_window_lookup( + offload_keys, + sliding_window_size_in_blocks, + req_status.req_context, + ) + if num_hit_blocks == 0: + return 0 + + if num_hit_blocks is None: + defer_lookup = True + else: + max_hit_size_tokens = min( + max_hit_size_tokens, + offloaded_block_size * (start_block_idx + num_hit_blocks), + ) + + new_num_hit_tokens = max_hit_size_tokens - num_computed_tokens + if new_num_hit_tokens < offloaded_block_size: + # we can only load less than a block, better skip + return 0 + + if new_num_hit_tokens < num_hit_tokens: + if defer_lookup: + # make another iteration on all groups to check + # if we still need to defer lookup + defer_lookup = False + lookup_groups = self._lookup_groups + elif looked_up_sliding_window and not lookup_groups: + # we need another iteration to confirm previously looked up + # sliding window works with the new_num_hit_tokens + lookup_groups = self._sliding_window_groups + + looked_up_sliding_window |= sliding_window_size_in_blocks is not None + num_hit_tokens = new_num_hit_tokens + + if defer_lookup: + logger.debug( + "Offloading manager delayed request %s as backend requested", + req_status.req.request_id, + ) + return None + + # possibly delay request if any of the hit blocks is already being loaded + if self._blocks_being_loaded: + for group_config, group_state in zip(self.config.kv_group_configs, req_status.group_states): + offloaded_block_size = group_config.offloaded_block_size + sliding_window_size_in_blocks = group_config.sliding_window_size_in_blocks + offload_keys = group_state.offload_keys + num_blocks = cdiv(num_computed_tokens + num_hit_tokens, offloaded_block_size) + start_block_idx = num_computed_tokens // offloaded_block_size + offload_keys = offload_keys[start_block_idx:num_blocks] + if sliding_window_size_in_blocks is not None: + offload_keys = offload_keys[-sliding_window_size_in_blocks:] + if any(key in self._blocks_being_loaded for key in offload_keys): + # hit blocks are being loaded, delay request + logger.debug( + "Delaying request %s since some of its blocks are already being loaded", + req_status.req.request_id, + ) + return None + + logger.debug( + "Request %s hit %s offloaded tokens after %s GPU hit tokens", + req_status.req.request_id, + num_hit_tokens, + num_computed_tokens, + ) + + return num_hit_tokens + def get_num_new_matched_tokens(self, request: Request, num_computed_tokens: int) -> tuple[int | None, bool]: """ Get number of new tokens that can be loaded beyond the @@ -195,89 +416,26 @@ def get_num_new_matched_tokens(self, request: Request, num_computed_tokens: int) - `True` if tokens will be loaded asynchronously (between scheduler steps). """ + is_new_request = False if req_status := self._req_status.get(request.request_id): # make sure block IDs are cleared for group_state in req_status.group_states: group_state.block_ids.clear() else: + is_new_request = True req_status = RequestOffloadState(config=self.config, req=request) self._req_status[request.request_id] = req_status req_status.update_offload_keys() req_status.num_locally_computed_tokens = num_computed_tokens - for gs in req_status.group_states: - self.manager.touch(gs.offload_keys) - - # Start with the full request size as the maximum loadable - max_hit_size_tokens: int = req_status.req.num_tokens - num_hit_tokens: int = 0 - defer_lookup = False - delay_request = False - for group_idx in self.lookup_groups: - group_config: GroupOffloadConfig = self.config.kv_group_configs[group_idx] - offloaded_block_size = group_config.offloaded_block_size - offload_keys = req_status.group_states[group_idx].offload_keys - - num_blocks = max_hit_size_tokens // offloaded_block_size - assert len(offload_keys) >= num_blocks - - # Constrain to block-aligned boundary for this group - max_hit_size_tokens = num_blocks * offloaded_block_size - num_hit_tokens = max_hit_size_tokens - num_computed_tokens - if num_hit_tokens < offloaded_block_size: - # we can only load less than a block, better skip - return 0, False - - start_block_idx = num_computed_tokens // offloaded_block_size - offload_keys = offload_keys[start_block_idx:num_blocks] - # Full attention relies on all previous KV cache blocks. - # Thus, we search for a maximal prefix of KV cache which are all cached. - block_hits = self._maximal_prefix_lookup(offload_keys, req_status.req_context) - if block_hits == 0: - return 0, False - - if block_hits is None: - defer_lookup = True - else: - # Further constrain based on what's actually available by backend - max_hit_size_tokens = offloaded_block_size * (start_block_idx + block_hits) - - num_hit_tokens = max_hit_size_tokens - num_computed_tokens - if num_hit_tokens < offloaded_block_size: - # we can only load less than a block, better skip - return 0, False - - if ( - block_hits - and self._blocks_being_loaded - and any(key in self._blocks_being_loaded for key in offload_keys[:block_hits]) - ): - # hit blocks are being loaded, delay request - delay_request = True - - if defer_lookup: - logger.debug( - "Offloading manager delayed request %s as backend requested", - req_status.req.request_id, - ) - return None, False - - if delay_request: - logger.debug( - "Delaying request %s since some of its blocks are already being loaded", - req_status.req.request_id, - ) - return None, False + num_hit_tokens = self._lookup(req_status) + if is_new_request: + req_status.update_num_hit_blocks(num_computed_tokens + (num_hit_tokens or 0)) - logger.debug( - "Request %s hit %s offloaded tokens after %s GPU hit tokens", - request.request_id, - num_hit_tokens, - num_computed_tokens, - ) + self._touch(req_status) - return num_hit_tokens, True + return num_hit_tokens, bool(num_hit_tokens) def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, num_external_tokens: int): if num_external_tokens == 0: @@ -317,6 +475,11 @@ def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, num_ assert num_locally_computed_tokens <= num_locally_computed_gpu_blocks * gpu_block_size num_pending_gpu_blocks = num_gpu_blocks - num_locally_computed_gpu_blocks + if group_config.sliding_window_size_in_blocks is not None: + assert ( + num_pending_gpu_blocks <= group_config.sliding_window_size_in_blocks * self.config.block_size_factor + ) + num_blocks = cdiv(num_cached_tokens, offloaded_block_size) assert len(offload_keys) >= num_blocks if num_pending_gpu_blocks: @@ -335,19 +498,39 @@ def update_state_after_alloc(self, request: Request, blocks: KVCacheBlocks, num_ # entire KV cache so a remote decode node can consume it. group_state.next_stored_block_idx = num_blocks + # Fence dst blocks against finished-request pending stores. + if self._block_id_to_pending_jobs and not self._block_id_to_pending_jobs.keys().isdisjoint(dst_block_ids): + self._current_batch_jobs_to_flush.update( + jid for bid in dst_block_ids for jid in self._block_id_to_pending_jobs.get(bid, ()) + ) + src_spec = self.manager.prepare_load(keys_to_load, req_status.req_context) dst_spec = GPULoadStoreSpec(dst_block_ids, group_sizes=group_sizes, block_indices=block_indices) - self._reqs_to_load[request.request_id] = (src_spec, dst_spec) - req_blocks_being_loaded = self._reqs_being_loaded[request.request_id] - req_blocks_being_loaded.update(keys_to_load) + load_job_id = self._generate_job_id() + self._current_batch_load_jobs[load_job_id] = TransferJob( + req_id=request.request_id, + transfer_spec=(src_spec, dst_spec), + ) + # a load can only be issued when no other jobs are pending. + assert not req_status.transfer_jobs + req_status.transfer_jobs.add(load_job_id) + self._jobs[load_job_id] = TransferJobStatus( + req_id=request.request_id, + pending_count=self.config.num_workers, + keys=set(keys_to_load), + is_store=False, + ) if self._blocks_being_loaded is not None: - self._blocks_being_loaded.update(req_blocks_being_loaded) + self._blocks_being_loaded.update(keys_to_load) - def _get_reqs_to_store(self, scheduler_output: SchedulerOutput) -> dict[ReqId, TransferSpec]: + def _build_store_jobs( + self, + scheduler_output: SchedulerOutput, + ) -> dict[int, TransferJob]: block_size_factor = self.config.block_size_factor - reqs_to_store: dict[ReqId, TransferSpec] = {} + store_jobs: dict[int, TransferJob] = {} # iterate over both new and cached requests for req_id, new_block_id_groups, preempted in yield_req_data(scheduler_output): req_status = self._req_status[req_id] @@ -360,6 +543,13 @@ def _get_reqs_to_store(self, scheduler_output: SchedulerOutput) -> dict[ReqId, T if new_block_id_groups: req_status.update_block_id_groups(new_block_id_groups) + # Fence new blocks against in-flight stores. + if self._block_id_to_pending_jobs: + new_blocks_flat = [bid for new_blocks in new_block_id_groups for bid in new_blocks] + if not self._block_id_to_pending_jobs.keys().isdisjoint(new_blocks_flat): + self._current_batch_jobs_to_flush.update( + jid for bid in new_blocks_flat for jid in self._block_id_to_pending_jobs.get(bid, ()) + ) num_scheduled_tokens = scheduler_output.num_scheduled_tokens[req_id] num_tokens_after_batch = req.num_computed_tokens + num_scheduled_tokens @@ -405,15 +595,17 @@ def _get_reqs_to_store(self, scheduler_output: SchedulerOutput) -> dict[ReqId, T req_status.advance_stored_idx(num_offloadable_tokens) continue - for group_state in req_status.group_states: - self.manager.touch(group_state.offload_keys) + self._touch(req_status) keys_to_store = set(store_output.keys_to_store) group_sizes: list[int] = [] block_indices: list[int] = [] src_block_ids: list[int] = [] + sliding_window_block_ids: list[int] = [] + non_sliding_window_block_ids: list[int] = [] for group_config, group_state in zip(self.config.kv_group_configs, req_status.group_states): + is_sliding_window = group_config.sliding_window_size_in_blocks is not None num_blocks = num_offloadable_tokens // group_config.offloaded_block_size start_block_idx = group_state.next_stored_block_idx block_ids = group_state.block_ids @@ -435,6 +627,11 @@ def _get_reqs_to_store(self, scheduler_output: SchedulerOutput) -> dict[ReqId, T elif start_gpu_block_idx is None: start_gpu_block_idx = gpu_block_idx + i src_block_ids.append(block_id) + if is_sliding_window: + sliding_window_block_ids.append(block_id) + else: + non_sliding_window_block_ids.append(block_id) + group_sizes.append(num_group_blocks) block_indices.append(start_gpu_block_idx or 0) group_state.next_stored_block_idx = num_blocks @@ -442,34 +639,57 @@ def _get_reqs_to_store(self, scheduler_output: SchedulerOutput) -> dict[ReqId, T src_spec = GPULoadStoreSpec(src_block_ids, group_sizes=group_sizes, block_indices=block_indices) dst_spec = store_output.store_spec - reqs_to_store[req_id] = (src_spec, dst_spec) - self._reqs_being_stored[req_id] |= keys_to_store + job_id = self._generate_job_id() + # a store can only be issued when no load is pending. + if req_status.transfer_jobs: + any_jid = next(iter(req_status.transfer_jobs)) + assert self._jobs[any_jid].is_store + req_status.transfer_jobs.add(job_id) + + # Watch sliding window blocks as they may get evicted + # before the request finishes + for bid in sliding_window_block_ids or (): + self._block_id_to_pending_jobs.setdefault(bid, set()).add(job_id) + + # the non-sliding window blocks will be watched only + # when the request finishes + self._jobs[job_id] = TransferJobStatus( + req_id=req_id, + pending_count=self.config.num_workers, + keys=set(keys_to_store), + is_store=True, + non_sliding_window_block_ids=non_sliding_window_block_ids, + sliding_window_block_ids=sliding_window_block_ids or None, + ) + + store_jobs[job_id] = TransferJob(req_id=req_id, transfer_spec=(src_spec, dst_spec)) logger.debug( - "Request %s offloading %s blocks upto %d tokens", + "Request %s offloading %s blocks upto %d tokens (job %d)", req_id, len(keys_to_store), num_offloadable_tokens, + job_id, ) - return reqs_to_store + return store_jobs def build_connector_meta(self, scheduler_output: SchedulerOutput) -> KVConnectorMetadata: - meta = OffloadingConnectorMetadata( - reqs_to_load=self._reqs_to_load, - reqs_to_store=self._get_reqs_to_store(scheduler_output), - reqs_to_flush=scheduler_output.preempted_req_ids, - ) - self._reqs_to_load = {} - - # NOTE (orozery): we should move this logic to update_connector_output - # once KVConnectorOutput allows us to report completed transfers for req_id in scheduler_output.preempted_req_ids or (): - keys = self._reqs_being_stored.get(req_id) - if keys: - self.manager.complete_store(keys) - keys.clear() + req_status = self._req_status.get(req_id) + if req_status is None or not req_status.transfer_jobs: + continue + any_jid = next(iter(req_status.transfer_jobs)) + assert self._jobs[any_jid].is_store + self._current_batch_jobs_to_flush.update(req_status.transfer_jobs) + meta = OffloadingConnectorMetadata( + load_jobs=self._current_batch_load_jobs, + store_jobs=self._build_store_jobs(scheduler_output), + jobs_to_flush=self._current_batch_jobs_to_flush, + ) + self._current_batch_load_jobs = {} + self._current_batch_jobs_to_flush = set() return meta def update_connector_output(self, connector_output: KVConnectorOutput): @@ -480,22 +700,43 @@ def update_connector_output(self, connector_output: KVConnectorOutput): connector_output (KVConnectorOutput): the worker-side connectors output. """ - for req_id in connector_output.finished_sending or []: - keys = self._reqs_being_stored.pop(req_id, None) - if keys: - self.manager.complete_store(keys) - - for req_id in connector_output.finished_recving or []: - keys = self._reqs_being_loaded.pop(req_id, None) - if keys: + meta = connector_output.kv_connector_worker_meta + if not isinstance(meta, OffloadingWorkerMetadata): + assert meta is None + meta = OffloadingWorkerMetadata() + for job_id, count in meta.completed_jobs.items(): + assert count > 0 + job_status = self._jobs[job_id] + job_status.pending_count -= count + if job_status.pending_count > 0: + continue + assert job_status.pending_count == 0 + + if job_status.is_store: + self.manager.complete_store(job_status.keys) + else: + self.manager.complete_load(job_status.keys) if self._blocks_being_loaded: - self._blocks_being_loaded.difference_update(keys) - self.manager.complete_load(keys) + self._blocks_being_loaded.difference_update(job_status.keys) + + req_status = self._req_status[job_status.req_id] + if self._block_id_to_pending_jobs: + # Sliding window blocks are tracked from store creation + # and must be cleaned up unconditionally. + self._remove_pending_job(job_id, job_status.sliding_window_block_ids) + # Non-sliding-window blocks are only tracked after + # request_finished, so only clean up for finished requests. + if req_status.req.is_finished(): + self._remove_pending_job(job_id, job_status.non_sliding_window_block_ids) + + del self._jobs[job_id] + req_status.transfer_jobs.remove(job_id) + if not req_status.transfer_jobs and req_status.req.is_finished(): + del self._req_status[job_status.req_id] def request_finished( self, request: Request, - block_ids: list[int], ) -> tuple[bool, dict[str, Any] | None]: """ Called when a request has finished, before its blocks are freed. @@ -507,14 +748,21 @@ def request_finished( Optional KVTransferParams to be included in the request outputs returned by the engine. """ - req_id = request.request_id - # TODO(orozery): possibly kickoff offload for last block # which may have been deferred due to async scheduling - self._req_status.pop(req_id, None) - - request_being_stored = req_id in self._reqs_being_stored - return request_being_stored, None + req_status = self._req_status.get(request.request_id) + if req_status is None: + return False, None + if not req_status.transfer_jobs: + del self._req_status[request.request_id] + return False, None + # Pending stores will outlive the request's block ownership. + # Register them so future block reuse triggers a flush. + for job_id in req_status.transfer_jobs: + job_status = self._jobs[job_id] + for bid in job_status.non_sliding_window_block_ids or (): + self._block_id_to_pending_jobs.setdefault(bid, set()).add(job_id) + return False, None def take_events(self) -> Iterable[KVCacheEvent]: """Take the KV cache events from the connector. diff --git a/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading/worker.py b/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading/worker.py index 7a2ab4669a..401af293bb 100644 --- a/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading/worker.py +++ b/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading/worker.py @@ -11,6 +11,7 @@ ) from aphrodite.distributed.kv_transfer.kv_connector.v1.offloading.common import ( OffloadingConnectorMetadata, + OffloadingWorkerMetadata, ReqId, ) from aphrodite.distributed.kv_transfer.kv_connector.v1.offloading.metrics import ( @@ -24,7 +25,7 @@ MambaSpec, UniformTypeKVCacheSpecs, ) -from aphrodite.v1.kv_offload.spec import ( +from aphrodite.v1.kv_offload.base import ( CanonicalKVCacheRef, CanonicalKVCaches, CanonicalKVCacheTensor, @@ -45,24 +46,11 @@ def __init__(self, spec: OffloadingSpec): self.spec = spec self.worker = OffloadingWorker() - self._job_counter = 0 - self.kv_connector_stats = OffloadingConnectorStats() - # req_id -> (job_id, store) - self._jobs: dict[int, tuple[ReqId, bool]] = {} - # req_id -> active job IDs - self._load_job: dict[ReqId, int] = {} - # req_id -> set(active job IDs) - self._store_jobs = defaultdict[ReqId, set[int]](set) - # list of store jobs pending submission (job_id, transfer_spec) + # job_id -> req_id for in-flight loads. + self._load_jobs: dict[int, ReqId] = {} self._unsubmitted_store_jobs: list[tuple[int, TransferSpec]] = [] - - self._finished_reqs_waiting_for_store: set[ReqId] = set() - - def _generate_job_id(self) -> int: - job_id = self._job_counter - self._job_counter = job_id + 1 - return job_id + self._connector_worker_meta = OffloadingWorkerMetadata() def _register_handlers(self, kv_caches: CanonicalKVCaches): for src_cls, dst_cls, handler in self.spec.get_handlers(kv_caches): @@ -274,10 +262,8 @@ def handle_preemptions(self, kv_connector_metadata: OffloadingConnectorMetadata) assert success self._unsubmitted_store_jobs.clear() - for req_id in kv_connector_metadata.reqs_to_flush or (): - job_ids = self._store_jobs.get(req_id) - if job_ids: - self.worker.wait(job_ids) + if kv_connector_metadata.jobs_to_flush: + self.worker.wait(kv_connector_metadata.jobs_to_flush) def start_kv_transfers(self, metadata: OffloadingConnectorMetadata): for job_id, transfer_spec in self._unsubmitted_store_jobs: @@ -285,41 +271,33 @@ def start_kv_transfers(self, metadata: OffloadingConnectorMetadata): assert success self._unsubmitted_store_jobs.clear() - for req_id, transfer_spec in metadata.reqs_to_load.items(): - job_id = self._generate_job_id() - self._jobs[job_id] = (req_id, False) - assert req_id not in self._load_job - self._load_job[req_id] = job_id - success = self.worker.transfer_async(job_id, transfer_spec) + for job_id, entry in metadata.load_jobs.items(): + self._load_jobs[job_id] = entry.req_id + success = self.worker.transfer_async(job_id, entry.transfer_spec) assert success def prepare_store_kv(self, metadata: OffloadingConnectorMetadata): - for req_id, transfer_spec in metadata.reqs_to_store.items(): - job_id = self._generate_job_id() - self._jobs[job_id] = (req_id, True) - self._store_jobs[req_id].add(job_id) - # NOTE(orozery): defer the store to the beginning of the next engine step, - # so that offloading starts AFTER transfers related to token sampling, - # thereby avoiding delays to token generation due to offloading. - self._unsubmitted_store_jobs.append((job_id, transfer_spec)) + for job_id, entry in metadata.store_jobs.items(): + # NOTE(orozery): defer the store to the beginning of the next + # engine step, so that offloading starts AFTER transfers related + # to token sampling, thereby avoiding delays to token generation. + self._unsubmitted_store_jobs.append((job_id, entry.transfer_spec)) def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: """ - Notifies worker-side connector ids of requests that have - finished generating tokens. - Returns a list of request IDs that finished loading or storing. - Returns: - ids of requests that have finished asynchronous transfer - tuple of (sending/saving ids, recving/loading ids). + tuple of (finished_sending, finished_recving). Stores never + emit finished_sending — the scheduler tracks store completion + via kv_connector_worker_meta.completed_jobs and fences any + block reuse via jobs_to_flush. Loads still emit + finished_recving so the base scheduler can resume requests + blocked on remote KV (and free aborted-during-load reqs). """ - finished_sending = set() - finished_recving = set() + finished_recving: set[str] = set() for transfer_result in self.worker.get_finished(): # we currently do not support job failures job_id = transfer_result.job_id assert transfer_result.success - req_id, store = self._jobs.pop(job_id) if ( transfer_result.transfer_time and transfer_result.transfer_size is not None @@ -330,31 +308,21 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: time=transfer_result.transfer_time, transfer_type=transfer_result.transfer_type, ) - if store: - req_jobs = self._store_jobs[req_id] - req_jobs.remove(job_id) - if req_jobs: - continue - - if req_id in self._finished_reqs_waiting_for_store: - self._finished_reqs_waiting_for_store.remove(req_id) - finished_sending.add(req_id) - del self._store_jobs[req_id] - else: - req_job = self._load_job[req_id] - assert job_id == req_job - del self._load_job[req_id] + + self._connector_worker_meta.mark_completed(job_id) + req_id = self._load_jobs.pop(job_id, None) + if req_id is not None: finished_recving.add(req_id) - for req_id in finished_req_ids: - pending_req_jobs = self._store_jobs.get(req_id) - if pending_req_jobs: - self._finished_reqs_waiting_for_store.add(req_id) - elif pending_req_jobs is not None: - finished_sending.add(req_id) - del self._store_jobs[req_id] + return set(), finished_recving - return finished_sending, finished_recving + def build_connector_worker_meta(self) -> OffloadingWorkerMetadata | None: + """Return completed transfer job IDs since the last call.""" + if not self._connector_worker_meta.completed_jobs: + return None + meta = self._connector_worker_meta + self._connector_worker_meta = OffloadingWorkerMetadata() + return meta def get_kv_connector_stats(self) -> KVConnectorStats | None: """ @@ -369,11 +337,7 @@ def get_kv_connector_stats(self) -> KVConnectorStats | None: return kv_connector_stats def shutdown(self) -> None: - # Drop deferred store jobs: there is no point in submitting - # them during shutdown. self._unsubmitted_store_jobs.clear() - self._jobs.clear() - self._load_job.clear() - self._store_jobs.clear() - self._finished_reqs_waiting_for_store.clear() + self._load_jobs.clear() + self._connector_worker_meta = OffloadingWorkerMetadata() self.worker.shutdown() diff --git a/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading_connector.py b/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading_connector.py index a272184f57..f856c6c5d2 100644 --- a/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +++ b/aphrodite/distributed/kv_transfer/kv_connector/v1/offloading_connector.py @@ -10,6 +10,7 @@ from aphrodite.distributed.kv_transfer.kv_connector.v1 import ( KVConnectorBase_V1, KVConnectorRole, + SupportsHMA, ) from aphrodite.distributed.kv_transfer.kv_connector.v1.base import KVConnectorMetadata from aphrodite.distributed.kv_transfer.kv_connector.v1.metrics import ( @@ -20,6 +21,7 @@ ) from aphrodite.distributed.kv_transfer.kv_connector.v1.offloading.common import ( OffloadingConnectorMetadata, + OffloadingWorkerMetadata, ) from aphrodite.distributed.kv_transfer.kv_connector.v1.offloading.metrics import ( OffloadingConnectorStats, @@ -41,7 +43,7 @@ from aphrodite.v1.request import Request -class OffloadingConnector(KVConnectorBase_V1): +class OffloadingConnector(KVConnectorBase_V1, SupportsHMA): @property def prefer_cross_layer_blocks(self) -> bool: return True @@ -109,6 +111,11 @@ def get_finished(self, finished_req_ids: set[str]) -> tuple[set[str], set[str]]: assert self.connector_worker is not None return self.connector_worker.get_finished(finished_req_ids) + def build_connector_worker_meta(self) -> OffloadingWorkerMetadata | None: + if self.connector_worker is not None: + return self.connector_worker.build_connector_worker_meta() + return None + def get_num_new_matched_tokens(self, request: "Request", num_computed_tokens: int) -> tuple[int | None, bool]: assert self.connector_scheduler is not None return self.connector_scheduler.get_num_new_matched_tokens(request, num_computed_tokens) @@ -131,7 +138,15 @@ def request_finished( block_ids: list[int], ) -> tuple[bool, dict[str, Any] | None]: assert self.connector_scheduler is not None - return self.connector_scheduler.request_finished(request, block_ids) + return self.connector_scheduler.request_finished(request) + + def request_finished_all_groups( + self, + request: "Request", + block_ids: tuple[list[int], ...], + ) -> tuple[bool, dict[str, Any] | None]: + assert self.connector_scheduler is not None + return self.connector_scheduler.request_finished(request) def take_events(self) -> Iterable[KVCacheEvent]: assert self.connector_scheduler is not None diff --git a/aphrodite/distributed/parallel_state.py b/aphrodite/distributed/parallel_state.py index cf93af7f1c..4bdf0f43b3 100644 --- a/aphrodite/distributed/parallel_state.py +++ b/aphrodite/distributed/parallel_state.py @@ -447,6 +447,7 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None # only cuda uses this function, # so we don't abstract it into the base class maybe_ca_context = nullcontext() + maybe_aiter_context = nullcontext() from aphrodite.distributed.device_communicators.cuda_communicator import ( CudaCommunicator, ) @@ -457,13 +458,20 @@ def graph_capture(self, graph_capture_context: GraphCaptureContext | None = None if ca_comm is not None: maybe_ca_context = ca_comm.capture() # type: ignore + from aphrodite._aiter_ops import rocm_aiter_ops + + if rocm_aiter_ops.is_enabled(): + aiter_ar = rocm_aiter_ops.get_aiter_allreduce() + if aiter_ar is not None: + maybe_aiter_context = aiter_ar.capture() # type: ignore + # ensure all initialization operations complete before attempting to # capture the graph on another stream curr_stream = torch.cuda.current_stream() if curr_stream != stream: stream.wait_stream(curr_stream) - with torch.cuda.stream(stream), maybe_ca_context: + with torch.cuda.stream(stream), maybe_ca_context, maybe_aiter_context: yield graph_capture_context def all_reduce(self, input_: torch.Tensor) -> torch.Tensor: diff --git a/aphrodite/engine/protocol.py b/aphrodite/engine/protocol.py index 3477293406..71440eee93 100644 --- a/aphrodite/engine/protocol.py +++ b/aphrodite/engine/protocol.py @@ -75,6 +75,7 @@ def generate( priority: int = 0, data_parallel_rank: int | None = None, reasoning_ended: bool | None = None, + reasoning_parser_kwargs: dict[str, Any] | None = None, ) -> AsyncGenerator[RequestOutput, None]: """Generate outputs for a request.""" ... diff --git a/aphrodite/entrypoints/anthropic/protocol.py b/aphrodite/entrypoints/anthropic/protocol.py index add0dfbffb..ad95ef800c 100644 --- a/aphrodite/entrypoints/anthropic/protocol.py +++ b/aphrodite/entrypoints/anthropic/protocol.py @@ -39,6 +39,7 @@ class AnthropicContentBlock(BaseModel): "image", "tool_use", "tool_result", + "tool_reference", "thinking", "redacted_thinking", ] @@ -52,6 +53,8 @@ class AnthropicContentBlock(BaseModel): input: dict[str, Any] | None = None content: str | list[dict[str, Any]] | None = None is_error: bool | None = None + # For tool_reference content + tool_name: str | None = None # For thinking content thinking: str | None = None signature: str | None = None @@ -72,6 +75,7 @@ class AnthropicTool(BaseModel): name: str description: str | None = None input_schema: dict[str, Any] + defer_loading: bool | None = None @field_validator("input_schema") @classmethod diff --git a/aphrodite/entrypoints/anthropic/serving.py b/aphrodite/entrypoints/anthropic/serving.py index c4e506be20..9e1eace211 100644 --- a/aphrodite/entrypoints/anthropic/serving.py +++ b/aphrodite/entrypoints/anthropic/serving.py @@ -233,6 +233,10 @@ def _convert_block( cls._convert_tool_use_block(block, tool_calls) elif block.type == "tool_result": cls._convert_tool_result_block(block, role, openai_messages, content_parts) + elif block.type == "tool_reference": + # Tool references are expanded during tool_result processing + # when they appear inside tool_result content. + pass @classmethod def _convert_tool_use_block(cls, block, tool_calls: list[dict[str, Any]]) -> None: @@ -267,6 +271,7 @@ def _convert_user_tool_result(cls, block, openai_messages: list[dict[str, Any]]) """Convert user tool_result with text and image support""" tool_text = "" tool_image_urls: list[str] = [] + tool_reference: list[dict[str, Any]] = [] if isinstance(block.content, str): tool_text = block.content @@ -283,6 +288,10 @@ def _convert_user_tool_result(cls, block, openai_messages: list[dict[str, Any]]) url = cls._convert_image_source_to_url(source) if url: tool_image_urls.append(url) + elif item_type == "tool_reference": + ref_name = item.get("tool_name") or item.get("name") + if ref_name: + tool_reference.append({"type": "tool_reference", "name": ref_name}) tool_text = "\n".join(text_parts) openai_messages.append( @@ -303,6 +312,15 @@ def _convert_user_tool_result(cls, block, openai_messages: list[dict[str, Any]]) } ) + if tool_reference: + openai_messages.append( + { + "role": "tool", + "tool_call_id": block.tool_use_id or "", + "content": tool_reference, # type: ignore[dict-item] + } + ) + @classmethod def _build_base_request( cls, @@ -389,6 +407,7 @@ def _convert_tools( "name": tool.name, "description": tool.description, "parameters": tool.input_schema, + "defer_loading": tool.defer_loading, }, } ) diff --git a/aphrodite/entrypoints/chat_utils.py b/aphrodite/entrypoints/chat_utils.py index c8406dcfb5..26f7ec98a9 100644 --- a/aphrodite/entrypoints/chat_utils.py +++ b/aphrodite/entrypoints/chat_utils.py @@ -11,7 +11,7 @@ from functools import cached_property, lru_cache, partial from itertools import accumulate from pathlib import Path -from typing import TYPE_CHECKING, Any, Generic, Literal, TypeAlias, TypeVar, cast +from typing import TYPE_CHECKING, Any, Final, Generic, Literal, TypeAlias, TypeVar, cast from openai.types.chat import ( ChatCompletionAssistantMessageParam, @@ -36,10 +36,11 @@ from pydantic import BaseModel, ConfigDict, TypeAdapter # pydantic needs the TypedDict from typing_extensions -from typing_extensions import Required, TypedDict +from typing_extensions import Required, TypedDict, override from aphrodite import envs from aphrodite.config import ModelConfig +from aphrodite.exceptions import APHRODITEValidationError from aphrodite.inputs import MultiModalDataDict, MultiModalUUIDDict from aphrodite.logger import init_logger from aphrodite.model_executor.models import SupportsMultiModal @@ -54,6 +55,10 @@ ) from aphrodite.multimodal.media import MEDIA_CONNECTOR_REGISTRY, MediaConnector from aphrodite.multimodal.processing import BaseMultiModalProcessor +from aphrodite.renderers.embed_utils import ( + safe_load_prompt_embeds, + safe_load_prompt_embeds_async, +) from aphrodite.utils import random_uuid from aphrodite.utils.collection_utils import is_list_of from aphrodite.utils.import_utils import LazyLoader @@ -97,9 +102,36 @@ class ChatTemplateResolutionError(ValueError): "image": "<##IMAGE##>", "audio": "<##AUDIO##>", "video": "<##VIDEO##>", + "prompt_embeds": "<##PROMPT_EMBEDS##>", } +PROMPT_EMBEDS_PLACEHOLDER_TOKEN: Final[str] = "" +"""The special token used as a placeholder for each embedding +position during chat template rendering. + +Registered as an additional special token when `--enable-prompt-embeds` is set. +See `_ensure_prompt_embeds_placeholder_token` in `aphrodite/renderers/hf.py`. +""" + + +_REQUIRE_MM_PROCESSOR_ERROR: Final[str] = ( + "Resolving modality {modality!r} requires a multimodal processor but none is available." +) + +_ENABLE_PROMPT_EMBEDS_ERROR: Final[str] = "You must set `--enable-prompt-embeds` to input `prompt_embeds`" + +_PROMPT_EMBEDS_MISSING_DATA_ERROR: Final[str] = ( + "prompt_embeds content part requires a non-empty `data` field with base64-encoded tensor bytes." +) + +_RESERVED_PLACEHOLDER_IN_TEXT_ERROR: Final[str] = ( + "Text content may not contain the reserved placeholder {token!r}. " + "This placeholder is used internally to mark `prompt_embeds` splice " + "positions in the tokenized prompt." +) + + class AudioURL(TypedDict, total=False): url: Required[str] """ @@ -146,6 +178,17 @@ class ChatCompletionContentPartAudioEmbedsParam(TypedDict, total=False): """ +class ChatCompletionContentPartPromptEmbedsParam(TypedDict, total=False): + data: Required[str] + """ + Base64-encoded bytes of a serialized `torch.Tensor` of shape + `(num_tokens, hidden_size)`. The tensor's `dtype` and `hidden_size` must + match the model's input embedding layer. + """ + type: Required[Literal["prompt_embeds"]] + """The type of the content part.""" + + class VideoURL(TypedDict, total=False): url: Required[str] """ @@ -254,6 +297,23 @@ class CustomThinkCompletionContentParam(TypedDict, total=False): """The thinking type.""" +class CustomChatCompletionContentToolReferenceParam(TypedDict, total=False): + """A tool reference content param that only accepts a plain tool name. + + Example: + { + "name": "get_weather", + "type": "tool_reference" + } + """ + + name: str + """The name of the tool being referenced.""" + + type: Literal["tool_reference"] + """The content type.""" + + ChatCompletionContentPartParam: TypeAlias = ( OpenAIChatCompletionContentPartParam | ChatCompletionContentPartAudioParam @@ -264,8 +324,10 @@ class CustomThinkCompletionContentParam(TypedDict, total=False): | CustomChatCompletionContentSimpleImageParam | ChatCompletionContentPartImageEmbedsParam | ChatCompletionContentPartAudioEmbedsParam + | ChatCompletionContentPartPromptEmbedsParam | CustomChatCompletionContentSimpleAudioParam | CustomChatCompletionContentSimpleVideoParam + | CustomChatCompletionContentToolReferenceParam | str | CustomThinkCompletionContentParam ) @@ -345,7 +407,15 @@ class ConversationMessage(TypedDict, total=False): ChatTemplateContentFormat = Literal["string", "openai"] -ModalityStr = Literal["image", "audio", "video", "image_embeds", "audio_embeds", "vision_chunk"] +ModalityStr = Literal[ + "image", + "audio", + "video", + "image_embeds", + "audio_embeds", + "vision_chunk", + "prompt_embeds", +] _T = TypeVar("_T") @@ -503,7 +573,17 @@ def add(self, modality: ModalityStr, item: _T) -> str | None: An optional uuid can be added which serves as a unique identifier of the media. + + Note: + `prompt_embeds` bypass MM-processor validation because they are + pre-computed embeddings that do not go through any HF processor, encoder, + or model-specific placeholder logic. The corresponding placeholder string is + managed by the parser via `_add_placeholder`, so we return None here. """ + if modality == "prompt_embeds": + self._items_by_modality["prompt_embeds"].append(item) + return None + input_modality = modality.replace("_embeds", "") original_modality = modality use_vision_chunk = self.use_unified_vision_chunk_modality and original_modality in ["video", "image"] @@ -605,17 +685,30 @@ def _resolve_vision_chunk_items( def _resolve_items( items_by_modality: dict[str, list[tuple[object, str | None]]], - mm_processor: BaseMultiModalProcessor, + mm_processor: BaseMultiModalProcessor | None, modality_order: dict[str, list[str]], ) -> tuple[MultiModalDataDict, MultiModalUUIDDict]: + """ + Materialize the tracker's per-modality items into `mm_data` / `mm_uuids`. + + Note: + `mm_processor` is `None` for text-only models (no registered HF + processor) whose only modality is `prompt_embeds`. Every other + modality requires a processor, enforced by the guard below. + """ if "image" in items_by_modality and "image_embeds" in items_by_modality: raise ValueError("Mixing raw image and embedding inputs is not allowed") if "audio" in items_by_modality and "audio_embeds" in items_by_modality: raise ValueError("Mixing raw audio and embedding inputs is not allowed") + # `prompt_embeds` bypasses HF MM processors. Every other modality requires one. + processor_modalities = items_by_modality.keys() - {"prompt_embeds"} + if processor_modalities and mm_processor is None: + raise RuntimeError(_REQUIRE_MM_PROCESSOR_ERROR.format(modality=processor_modalities)) mm_data = {} mm_uuids = {} if "image_embeds" in items_by_modality: + assert mm_processor is not None mm_data["image"] = _get_embeds_data( "image", [data for data, uuid in items_by_modality["image_embeds"]], @@ -626,6 +719,7 @@ def _resolve_items( mm_data["image"] = [data for data, uuid in items_by_modality["image"]] mm_uuids["image"] = [uuid for data, uuid in items_by_modality["image"]] if "audio_embeds" in items_by_modality: + assert mm_processor is not None mm_data["audio"] = _get_embeds_data( "audio", [data for data, uuid in items_by_modality["audio_embeds"]], @@ -639,6 +733,7 @@ def _resolve_items( mm_data["video"] = [data for data, uuid in items_by_modality["video"]] mm_uuids["video"] = [uuid for data, uuid in items_by_modality["video"]] if "vision_chunk" in items_by_modality: + assert mm_processor is not None # Process vision_chunk items - extract from (data, modality) tuples # and convert to VisionChunk types with proper UUID handling processed_chunks, vision_chunk_uuids = _resolve_vision_chunk_items( @@ -648,6 +743,8 @@ def _resolve_items( ) mm_data["vision_chunk"] = processed_chunks mm_uuids["vision_chunk"] = vision_chunk_uuids + if "prompt_embeds" in items_by_modality: + mm_data["prompt_embeds"] = [data for data, _uuid in items_by_modality["prompt_embeds"]] return mm_data, mm_uuids @@ -659,7 +756,15 @@ def resolve_items( if not self._items_by_modality: return None, None - return _resolve_items(dict(self._items_by_modality), self.mm_processor, self._modality_order) + # Text-only models (`is_multimodal_model=False`) with inputs of + # modality `prompt_embeds` have no MM processor since `prompt_embeds` are + # pre-computed and require no processing, so we pass `None`. + mm_processor = self.mm_processor if self._model_config.is_multimodal_model else None + return _resolve_items( + dict(self._items_by_modality), + mm_processor, + self._modality_order, + ) def create_parser(self, mm_processor_kwargs: dict[str, Any] | None = None) -> "BaseMultiModalContentParser": return MultiModalContentParser(self, mm_processor_kwargs=mm_processor_kwargs) @@ -676,7 +781,12 @@ async def resolve_items( modality: await asyncio.gather(*coros) for modality, coros in self._items_by_modality.items() } - return _resolve_items(resolved_items_by_modality, self.mm_processor, self._modality_order) + mm_processor = self.mm_processor if self._model_config.is_multimodal_model else None + return _resolve_items( + resolved_items_by_modality, + mm_processor, + self._modality_order, + ) def create_parser(self, mm_processor_kwargs: dict[str, Any] | None = None) -> "BaseMultiModalContentParser": return AsyncMultiModalContentParser(self, mm_processor_kwargs=mm_processor_kwargs) @@ -690,10 +800,16 @@ def __init__(self) -> None: # general MM placeholder: # { # "<##IMAGE##>": ["", "", ""], - # "<##AUDIO##>": ["