Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -373,6 +373,39 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// Schema note: the `!` on `q_weight` is the schema's alias annotation marking
// that argument as written to by the op, and `-> ()` declares no return value.
// The implementation is registered under the CUDA dispatch key.
ops.def("gptq_shuffle(Tensor! q_weight, Tensor q_perm, int bit) -> ()");
ops.impl("gptq_shuffle", torch::kCUDA, &gptq_shuffle);

// Dequantization for GGML.
// Schema: takes tensor `W`, an integer `type`, symbolic sizes `m` and `n`, and
// an optional ScalarType `dtype`; returns a Tensor. The implementation is
// registered under the CUDA dispatch key.
// NOTE(review): `type` is presumably the GGML quantization type id and `m`/`n`
// the dequantized output dimensions — confirm against ggml_dequantize itself.
ops.def(
"ggml_dequantize(Tensor W, int type, SymInt m, SymInt n, ScalarType? "
"dtype) -> Tensor");
ops.impl("ggml_dequantize", torch::kCUDA, &ggml_dequantize);

// mmvq kernel for GGML.
// Schema: takes tensors `W` and `X`, an integer `type`, and a symbolic size
// `row`; returns a Tensor. The implementation is registered under the CUDA
// dispatch key.
// NOTE(review): "mmvq" presumably means the matrix-vector product path on a
// quantized weight — confirm against the kernel source.
ops.def(
"ggml_mul_mat_vec_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
ops.impl("ggml_mul_mat_vec_a8", torch::kCUDA, &ggml_mul_mat_vec_a8);

// mmq kernel for GGML.
// Schema: same argument list as ggml_mul_mat_vec_a8 — tensors `W` and `X`, an
// integer `type`, and a symbolic size `row`; returns a Tensor. The
// implementation is registered under the CUDA dispatch key.
// Note the argument order here is (W, X), whereas the MoE variants
// ggml_moe_a8 / ggml_moe_a8_vec declare (X, W).
ops.def("ggml_mul_mat_a8(Tensor W, Tensor X, int type, SymInt row) -> Tensor");
ops.impl("ggml_mul_mat_a8", torch::kCUDA, &ggml_mul_mat_a8);

// mmq kernel for GGML (MoE).
// Schema: takes tensors `X`, `W`, `sorted_token_ids`, `expert_ids` and
// `num_tokens_post_padded`, an integer `type`, and symbolic sizes `row`,
// `top_k` and `tokens`; returns a Tensor. The implementation is registered
// under the CUDA dispatch key.
// NOTE(review): the three index tensors presumably come from a token/expert
// alignment step run before this op — confirm against the caller.
ops.def(
"ggml_moe_a8(Tensor X, Tensor W, "
"Tensor sorted_token_ids, Tensor expert_ids, "
"Tensor num_tokens_post_padded, int type, "
"SymInt row, SymInt top_k, SymInt tokens) -> Tensor");
ops.impl("ggml_moe_a8", torch::kCUDA, &ggml_moe_a8);

// mmvq kernel for GGML (MoE).
// Schema: takes tensors `X`, `W` and `topk_ids`, integers `top_k` and `type`,
// and symbolic sizes `row` and `tokens`; returns a Tensor. The implementation
// is registered under the CUDA dispatch key.
// NOTE(review): `top_k` is declared as plain `int` here and precedes `type`,
// while ggml_moe_a8 declares it as `SymInt` after `row` — confirm the
// difference is intentional before relying on either order.
ops.def(
"ggml_moe_a8_vec(Tensor X, Tensor W, "
"Tensor topk_ids, int top_k, "
"int type, SymInt row, SymInt tokens) -> Tensor");
ops.impl("ggml_moe_a8_vec", torch::kCUDA, &ggml_moe_a8_vec);

// Block-size query for the GGML MoE kernels.
// Schema: takes an integer `type` and returns an integer; no tensors are
// involved. Unlike the other GGML ops registered here, the implementation is
// bound without a dispatch key argument.
// NOTE(review): presumably this is deliberate because there is no Tensor
// argument to dispatch on — confirm.
ops.def("ggml_moe_get_block_size(int type) -> int");
ops.impl("ggml_moe_get_block_size", &ggml_moe_get_block_size);

// ┌---------- Not supported for Metax -----------┐
// Compute FP8 quantized tensor for given scaling factor.
// Supports per-tensor, per-channel, per-token, and arbitrary 2D group
Expand Down