Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
55 changes: 37 additions & 18 deletions csrc/moe/fp32_router_gemm.cu
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ __device__ __forceinline__ void load_activation<__nv_bfloat16, 8>(
// Weight is always fp32; output is always fp32.
// VPT = 16 / sizeof(InputT): 4 for fp32, 8 for bf16
template <typename InputT, int kBlockSize, int kNumTokens, int kNumExperts,
int kHiddenDim>
int kHiddenDim, bool ENABLE_PDL>
__global__ __launch_bounds__(128, 1) void fp32_router_gemm_kernel(
float* out, InputT const* mat_a, float const* mat_b) {
constexpr int VPT = 16 / sizeof(InputT);
Expand All @@ -103,9 +103,11 @@ __global__ __launch_bounds__(128, 1) void fp32_router_gemm_kernel(
k_bases[ki] = ki * k_elems_per_k_iteration + tid * VPT;
}

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
asm volatile("griddepcontrol.launch_dependents;");
asm volatile("griddepcontrol.wait;");
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) && \
defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if constexpr (ENABLE_PDL) {
asm volatile("griddepcontrol.wait;");
}
#endif

for (int ki = 0; ki < k_iterations; ki++) {
Expand Down Expand Up @@ -149,6 +151,14 @@ __global__ __launch_bounds__(128, 1) void fp32_router_gemm_kernel(
out[m * kNumExperts + n_idx] = final_sum;
}
}

#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) && \
defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if constexpr (ENABLE_PDL) {
__syncthreads();
asm volatile("griddepcontrol.launch_dependents;");
}
#endif
}

// ---------------------------------------------------------------------------
Expand All @@ -159,20 +169,29 @@ template <typename InputT, int kNumTokens, int kNumExperts, int kHiddenDim>
void invokeFp32RouterGemm(float* output, InputT const* mat_a,
float const* mat_b, cudaStream_t stream) {
constexpr int kBlockSize = 128;
cudaLaunchConfig_t config;
config.gridDim = kNumExperts;
config.blockDim = kBlockSize;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config,
fp32_router_gemm_kernel<InputT, kBlockSize, kNumTokens,
kNumExperts, kHiddenDim>,
output, mat_a, mat_b);
#if defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
if (getEnvEnablePDL()) {
cudaLaunchConfig_t config;
config.gridDim = kNumExperts;
config.blockDim = kBlockSize;
config.dynamicSmemBytes = 0;
config.stream = stream;
cudaLaunchAttribute attrs[1];
attrs[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
attrs[0].val.programmaticStreamSerializationAllowed = 1;
config.numAttrs = 1;
config.attrs = attrs;
cudaLaunchKernelEx(&config,
fp32_router_gemm_kernel<InputT, kBlockSize, kNumTokens,
kNumExperts, kHiddenDim, true>,
output, mat_a, mat_b);
return;
}
#endif

fp32_router_gemm_kernel<InputT, kBlockSize, kNumTokens, kNumExperts,
kHiddenDim, false>
<<<kNumExperts, kBlockSize, 0, stream>>>(output, mat_a, mat_b);
}

// ---------------------------------------------------------------------------
Expand Down
23 changes: 21 additions & 2 deletions csrc/moe/fp32_router_gemm_entry.cu
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/all.h>
#include <cuda_bf16.h>
#include <cuda_runtime.h>
Expand Down Expand Up @@ -54,11 +55,24 @@ void fp32_router_gemm(at::Tensor& output, // [num_tokens, num_experts]
const at::Tensor& mat_b // [num_experts, hidden_dim]
) {
TORCH_CHECK(output.dim() == 2 && mat_a.dim() == 2 && mat_b.dim() == 2);
TORCH_CHECK(output.is_cuda() && mat_a.is_cuda() && mat_b.is_cuda(),
"fp32_router_gemm: all tensors must be CUDA tensors");
TORCH_CHECK(output.get_device() == mat_a.get_device() &&
output.get_device() == mat_b.get_device(),
"fp32_router_gemm: all tensors must be on the same CUDA device");
TORCH_CHECK(output.is_contiguous() && mat_a.is_contiguous() &&
mat_b.is_contiguous(),
"fp32_router_gemm: all tensors must be contiguous");

const int num_tokens = mat_a.size(0);
const int num_experts = mat_b.size(0);
const int hidden_dim = mat_a.size(1);

TORCH_CHECK(output.size(0) == num_tokens && output.size(1) == num_experts,
"fp32_router_gemm: output must have shape [num_tokens, "
"num_experts], got [",
output.size(0), ", ", output.size(1), "], expected [",
num_tokens, ", ", num_experts, "]");
TORCH_CHECK(
mat_a.size(1) == mat_b.size(1),
"fp32_router_gemm: mat_a and mat_b must have the same hidden_dim");
Expand All @@ -68,8 +82,8 @@ void fp32_router_gemm(at::Tensor& output, // [num_tokens, num_experts]
TORCH_CHECK(num_experts == FP32_NUM_EXPERTS,
"fp32_router_gemm: expected num_experts=", FP32_NUM_EXPERTS,
", got ", num_experts);
TORCH_CHECK(num_tokens >= 1 && num_tokens <= FP32_MAX_TOKENS,
"fp32_router_gemm: num_tokens must be in [1, ", FP32_MAX_TOKENS,
TORCH_CHECK(num_tokens <= FP32_MAX_TOKENS,
"fp32_router_gemm: num_tokens must be in [0, ", FP32_MAX_TOKENS,
"], got ", num_tokens);
TORCH_CHECK(mat_a.dtype() == at::kFloat || mat_a.dtype() == at::kBFloat16,
"fp32_router_gemm: mat_a must be float32 or bfloat16");
Expand All @@ -78,6 +92,11 @@ void fp32_router_gemm(at::Tensor& output, // [num_tokens, num_experts]
TORCH_CHECK(output.dtype() == at::kFloat,
"fp32_router_gemm: output must be float32");

if (num_tokens == 0) {
return;
}

const at::cuda::OptionalCUDAGuard device_guard(device_of(mat_a));
const int sm = getSMVersion();
TORCH_CHECK(sm >= 90, "fp32_router_gemm: requires SM90+, got SM", sm);

Expand Down
22 changes: 18 additions & 4 deletions csrc/moe/topk_softmax_kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,13 @@ __device__ __forceinline__ float toFloat(T value) {
}
}

#ifndef USE_ROCM
// Host-side capability probe used before opting into a Programmatic Dependent
// Launch (PDL) kernel launch.
//
// Looks up the properties of the current CUDA device and returns true only
// when a properties record is available and its `major` field (the
// compute-capability major version) is at least 9. Returns false otherwise.
inline bool supportsPdlOnCurrentDevice() {
  if (const auto* device_props = at::cuda::getCurrentDeviceProperties()) {
    return device_props->major >= 9;
  }
  // No properties record: report the device as not PDL-capable.
  return false;
}
#endif

// Scoring function enums
enum ScoringFunc {
SCORING_SOFTMAX = 0, // apply softmax
Expand Down Expand Up @@ -315,9 +322,9 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
const int thread_row = warp_base_row + thread_row_in_warp;


#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) && \
defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if constexpr (ENABLE_PDL) {
asm volatile("griddepcontrol.launch_dependents;");
asm volatile("griddepcontrol.wait;");
}
#endif
Expand Down Expand Up @@ -569,6 +576,13 @@ __launch_bounds__(WARPS_PER_CTA* WARP_SIZE_PARAM) __global__
}
}

#if !defined(USE_ROCM) && defined(CUDA_VERSION) && (CUDA_VERSION >= 12000) && \
defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
if constexpr (ENABLE_PDL) {
asm volatile("griddepcontrol.launch_dependents;");
}
#endif

}

namespace detail
Expand Down Expand Up @@ -599,8 +613,8 @@ void topkGatingLauncherHelper(const InputType* input, const bool* finished, floa
const int num_blocks = (num_warps + WARPS_PER_TB - 1) / WARPS_PER_TB;

dim3 block_dim(WARP_SIZE_PARAM, WARPS_PER_TB);
#ifndef USE_ROCM
if (enable_pdl) {
#if !defined(USE_ROCM) && defined(CUDA_VERSION) && (CUDA_VERSION >= 12000)
if (enable_pdl && supportsPdlOnCurrentDevice()) {
cudaLaunchConfig_t config;
config.gridDim = num_blocks;
Comment thread
qianlihuang marked this conversation as resolved.
config.blockDim = block_dim;
Expand Down
11 changes: 11 additions & 0 deletions vllm/_custom_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2300,6 +2300,17 @@ def gpt_oss_router_gemm(
return output


# Register the fake implementation only when the compiled `_moe_C` extension
# actually exposes `fp32_router_gemm`; otherwise skip registration.
if hasattr(torch.ops, "_moe_C") and hasattr(torch.ops._moe_C, "fp32_router_gemm"):

    @register_fake("_moe_C::fp32_router_gemm")
    def fp32_router_gemm_fake(
        output: torch.Tensor,
        mat_a: torch.Tensor,
        mat_b: torch.Tensor,
    ) -> None:
        """Fake implementation of ``_moe_C::fp32_router_gemm``.

        The op returns nothing, so the fake version has no output metadata to
        produce and performs no work.

        Args:
            output: Destination tensor (presumably written in place by the
                real kernel; confirm against the C++ entry point).
            mat_a: First input matrix.
            mat_b: Second input matrix.
        """
        return None


def topk_softmax(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
Expand Down
4 changes: 4 additions & 0 deletions vllm/envs.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@
VLLM_MAIN_CUDA_VERSION: str = "12.9"
VLLM_FLOAT32_MATMUL_PRECISION: Literal["highest", "high", "medium"] = "highest"
VLLM_BATCH_INVARIANT: bool = False
TRTLLM_ENABLE_PDL: bool = False
MAX_JOBS: str | None = None
NVCC_THREADS: str | None = None
VLLM_USE_PRECOMPILED: bool = False
Expand Down Expand Up @@ -500,6 +501,9 @@ def _get_or_set_default() -> str:
# Enable batch-invariant mode: deterministic results regardless of
# batch composition. Requires NVIDIA GPU with compute capability >= 9.0.
"VLLM_BATCH_INVARIANT": lambda: bool(int(os.getenv("VLLM_BATCH_INVARIANT", "0"))),
# Enable Programmatic Dependent Launch for supported NVIDIA MoE router
# kernels. Requires CUDA >= 12.0 and compute capability >= 9.0.
"TRTLLM_ENABLE_PDL": lambda: bool(int(os.getenv("TRTLLM_ENABLE_PDL", "0"))),
# Maximum number of compilation jobs to run in parallel.
# By default this is the number of CPUs
"MAX_JOBS": lambda: os.getenv("MAX_JOBS", None),
Expand Down
4 changes: 4 additions & 0 deletions vllm/model_executor/layers/fused_moe/layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,8 @@ class FusedMoE(CustomOp):
quant_config: Quantization configure.
enable_eplb: Whether to enable expert parallelism load balancer.
router_logits_dtype: Data type for router logits buffers.
enable_router_pdl: Whether fused top-k routing kernels should join a
Programmatic Dependent Launch chain.
"""

# --8<-- [end:fused_moe]
Expand Down Expand Up @@ -272,6 +274,7 @@ def __init__(
gate: torch.nn.Module | None = None,
shared_experts: torch.nn.Module | None = None,
routed_input_transform: torch.nn.Module | None = None,
enable_router_pdl: bool = False,
):
super().__init__()

Expand Down Expand Up @@ -462,6 +465,7 @@ def __init__(
# TODO(bnell): once we can construct the MK at init time, we
# can make this a value.
indices_type_getter=lambda: self.quant_method.topk_indices_dtype,
enable_pdl=enable_router_pdl,
)
self.routing_method_type: RoutingMethodType = self.router.routing_method_type

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def vllm_topk_sigmoid(
gating_output: torch.Tensor,
renormalize: bool = False,
e_score_correction_bias: torch.Tensor | None = None,
enable_pdl: bool = True, # FIXME
enable_pdl: bool = False,
) -> tuple[torch.Tensor, ...]:
ops.topk_sigmoid(
topk_weights,
Expand Down Expand Up @@ -81,6 +81,7 @@ def fused_topk_bias(
renormalize: bool,
scoring_func: str = "softmax",
indices_type: torch.dtype | None = None,
enable_pdl: bool = False,
):
if not rocm_aiter_ops.is_fused_moe_enabled():
assert hidden_states.size(0) == gating_output.size(0), (
Expand Down Expand Up @@ -110,6 +111,7 @@ def fused_topk_bias(
gating_output,
renormalize,
e_score_correction_bias,
enable_pdl,
)
return topk_weights, topk_ids
elif scoring_func == "sigmoid":
Expand All @@ -120,6 +122,7 @@ def fused_topk_bias(
gating_output,
renormalize,
e_score_correction_bias,
enable_pdl,
)
return topk_weights, topk_ids
else:
Expand Down Expand Up @@ -186,6 +189,7 @@ def __init__(
routed_scaling_factor: float = 1.0,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
enable_pdl: bool = False,
):
super().__init__(
top_k=top_k,
Expand All @@ -198,6 +202,7 @@ def __init__(
self.renormalize = renormalize
self.scoring_func = scoring_func
self.routed_scaling_factor = routed_scaling_factor
self.enable_pdl = enable_pdl

@property
def routing_method_type(self) -> RoutingMethodType:
Expand All @@ -224,6 +229,7 @@ def _compute_routing(
renormalize=self.renormalize,
scoring_func=self.scoring_func,
indices_type=indices_type,
enable_pdl=self.enable_pdl,
)

if self.routed_scaling_factor != 1.0:
Expand Down
30 changes: 22 additions & 8 deletions vllm/model_executor/layers/fused_moe/router/fused_topk_router.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ def fused_topk(
renormalize: bool,
indices_type: torch.dtype | None = None,
scoring_func: str = "softmax",
enable_pdl: bool = False,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
assert hidden_states.size(0) == gating_output.size(0), "Number of tokens mismatch"

Expand All @@ -96,20 +97,30 @@ def fused_topk(
)

if scoring_func == "softmax":
topk_func = dispatch_topk_softmax_func(
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
)
use_rocm_aiter = rocm_aiter_ops.is_fused_moe_enabled()
topk_func = dispatch_topk_softmax_func(use_rocm_aiter=use_rocm_aiter)
pdl_kwargs = {} if use_rocm_aiter else {"enable_pdl": enable_pdl}
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
topk_weights,
topk_ids,
token_expert_indices,
gating_output,
renormalize,
**pdl_kwargs,
)

return topk_weights, topk_ids, token_expert_indices
elif scoring_func == "sigmoid":
topk_func = dispatch_topk_sigmoid_func(
use_rocm_aiter=rocm_aiter_ops.is_fused_moe_enabled()
)
use_rocm_aiter = rocm_aiter_ops.is_fused_moe_enabled()
topk_func = dispatch_topk_sigmoid_func(use_rocm_aiter=use_rocm_aiter)
pdl_kwargs = {} if use_rocm_aiter else {"enable_pdl": enable_pdl}
topk_weights, topk_ids = topk_func(
topk_weights, topk_ids, token_expert_indices, gating_output, renormalize
topk_weights,
topk_ids,
token_expert_indices,
gating_output,
renormalize,
**pdl_kwargs,
)

return topk_weights, topk_ids, token_expert_indices
Expand All @@ -129,6 +140,7 @@ def __init__(
renormalize: bool = True,
enable_eplb: bool = False,
indices_type_getter: Callable[[], torch.dtype | None] | None = None,
enable_pdl: bool = False,
):
super().__init__(
top_k=top_k,
Expand All @@ -139,6 +151,7 @@ def __init__(
)
self.renormalize = renormalize
self.scoring_func = scoring_func
self.enable_pdl = enable_pdl

@property
def routing_method_type(self) -> RoutingMethodType:
Expand All @@ -164,6 +177,7 @@ def _compute_routing(
renormalize=self.renormalize,
indices_type=indices_type,
scoring_func=self.scoring_func,
enable_pdl=self.enable_pdl,
)

return topk_weights, topk_ids
Loading