8 changes: 4 additions & 4 deletions examples/BuddyDeepSeekR1/import-deepseek-r1.py
@@ -39,8 +39,8 @@
     apply_classic_fusion,
     eliminate_transpose,
     eliminate_matmul_transpose_reshape,
-    flash_attention_prefill,
-    gqa_attention_fusion,
+    gqa_flash_attention_prefill_fusion,
+    gqa_attention_decode_fusion,
 )
 from buddy.compiler.graph.type import DeviceType
 from buddy.compiler.graph.operation import *
@@ -250,12 +250,12 @@
 pattern_list_prefill = [
     simply_fuse,
     apply_classic_fusion,
-    flash_attention_prefill,
+    gqa_flash_attention_prefill_fusion,
 ]
 pattern_list_decode = [
     simply_fuse,
     apply_classic_fusion,
-    gqa_attention_fusion,
+    gqa_attention_decode_fusion,
 ]

 graphs_prefill[0].fuse_ops(pattern_list_prefill)
14 changes: 7 additions & 7 deletions frontend/Python/graph/operation.py
@@ -869,7 +869,13 @@ def __init__(self) -> None:
         self._op_type = OpType.ElementwiseType


-class FlashAttentionForCpuPrefillOp(Op):
+class GQAFlashAttentionPrefillFusedOp(Op):
     def __init__(self) -> None:
         super().__init__()
         self._op_type = OpType.ElementwiseType
+
+
+class GQAAttentionDecodeFusedOp(Op):
+    def __init__(self) -> None:
+        super().__init__()
+        self._op_type = OpType.ElementwiseType
@@ -2520,12 +2526,6 @@ def __init__(self) -> None:
         self._op_type = OpType.ReshapeType


-class GQAAttentionFusedOp(Op):
-    def __init__(self) -> None:
-        super().__init__()
-        self._op_type = OpType.ElementwiseType
-
-
 class AsStridedScatterOp(Op):
     """
     Scatter into a base tensor via as_strided view.
4 changes: 2 additions & 2 deletions frontend/Python/graph/transform/__init__.py
@@ -21,8 +21,8 @@
 from .fuse_ops import (
     simply_fuse,
     apply_classic_fusion,
-    flash_attention_prefill,
-    gqa_attention_fusion,
+    gqa_flash_attention_prefill_fusion,
+    gqa_attention_decode_fusion,
 )
 from .useless_op_eliminate import maxpool2d_simplify
 from .eliminate_weight_transpose import eliminate_transpose
31 changes: 9 additions & 22 deletions frontend/Python/graph/transform/fuse_ops.py
@@ -25,8 +25,8 @@

 classicfuse_register = {
     "transpose_matmul_fusion": TransposeMatmulFusedOp,
-    "flash_attention_prefill_fusion": FlashAttentionForCpuPrefillOp,
-    "gqa_attention_fusion": GQAAttentionFusedOp,
+    "gqa_flash_attention_prefill_fusion": GQAFlashAttentionPrefillFusedOp,
+    "gqa_attention_decode_fusion": GQAAttentionDecodeFusedOp,
 }

 # TODO: classify op type for op fusion
@@ -140,13 +140,13 @@ def simply_fuse(graph: Graph):
     graph.group_map_device = {"subgraph0": device}


-def flash_attention_prefill(graph: Graph):
+def gqa_flash_attention_prefill_fusion(graph: Graph):
     """
-    Replace ScaledDotProductFlashAttentionForCpuOp with FlashAttentionForCpuPrefillOp.
+    Replace ScaledDotProductFlashAttentionForCpuOp with GQAFlashAttentionPrefillFusedOp.
     """
     new_op_group = []
     device = DeviceType.CPU
-    replace_attention_op(graph)
+    gqa_attention_fusion_check(graph, "gqa_flash_attention_prefill_fusion")

     for op in graph.body:
         if isinstance(op, PlaceholderOp):
@@ -157,20 +157,7 @@ def flash_attention_prefill(graph: Graph):
     graph.group_map_device = {"subgraph0": device}


-def replace_attention_op(graph: Graph):
-    """
-    Replace ScaledDotProductFlashAttentionForCpuOp with FlashAttentionForCpuPrefillOp.
-    """
-    for op in list(graph.body):
-        if isinstance(op, ScaledDotProductFlashAttentionForCpuOp):
-            new_op = classicfuse_register.get(
-                "flash_attention_prefill_fusion"
-            )()
-            new_op.name = "FlashAttentionForCpuPrefillOp"
-            graph.displace_node(op, new_op)
-
-
-def gqa_attention_fusion(graph: Graph):
+def gqa_attention_decode_fusion(graph: Graph):
     """
     Function to fuse GQA Attention operations into one operation and fuse
     all operations into one graph.
@@ -183,7 +170,7 @@
     """
     new_op_group = []
     device = DeviceType.CPU
-    gqa_attention_fusion_check(graph)
+    gqa_attention_fusion_check(graph, "gqa_attention_decode_fusion")
     for op in graph.body:
         if isinstance(op, PlaceholderOp):
             continue
@@ -193,7 +180,7 @@ def gqa_attention_fusion(graph: Graph):
     graph.group_map_device = {"subgraph0": device}


-def gqa_attention_fusion_check(graph: Graph):
+def gqa_attention_fusion_check(graph: Graph, new_op):
     for op in graph.body:
         # === GQA Attention pattern ===
         if isinstance(op, ScaledDotProductFlashAttentionForCpuOp):
@@ -256,7 +243,7 @@ def gqa_attention_fusion_check(graph: Graph):
                 v_slice1,
                 v_slice2,
                 v_cache_unsqueeze,
-                "gqa_attention_fusion",
+                new_op,
             )

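Note: with replace_attention_op removed, both the prefill and decode patterns now route node replacement through gqa_attention_fusion_check plus a fusion key. A minimal sketch of the replacement mechanics this subsumes, based on the removed replace_attention_op and the classicfuse_register shown above (the helper name replace_matched_attention is hypothetical):

    def replace_matched_attention(graph: Graph, fusion_key: str):
        # Iterate over a snapshot so displace_node can rewrite graph.body in place.
        for op in list(graph.body):
            if isinstance(op, ScaledDotProductFlashAttentionForCpuOp):
                # Look up the fused op class by its registry key, e.g.
                # "gqa_flash_attention_prefill_fusion" or "gqa_attention_decode_fusion".
                fused_op = classicfuse_register.get(fusion_key)()
                fused_op.name = type(fused_op).__name__
                graph.displace_node(op, fused_op)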
24 changes: 14 additions & 10 deletions frontend/Python/ops/tosa.py
@@ -76,7 +76,8 @@
     RandIntLowOp,
     ArgMaxOp,
     ScaledDotProductFlashAttentionForCpuOp,
-    FlashAttentionForCpuPrefillOp,
+    GQAFlashAttentionPrefillFusedOp,
+    GQAAttentionDecodeFusedOp,
     MatmulOp,
     LeOp,
     BitwiseAndTensorOp,
@@ -235,7 +236,6 @@
     LocalScalarDenseOp,
     ResizeOp,
     SplitWithSizesOp,
-    GQAAttentionFusedOp,
 )
 from .utils import *

@@ -3827,11 +3827,11 @@ def scaled_dot_product_flash_attention_for_cpu_op(
     return result_reshape_op, log_sumexp


-def flash_attention_for_cpu_prefill_op(
-    node: "FlashAttentionForCpuPrefillOp", symbol_table
+def gqa_flash_attention_prefill_fused_op(
+    node: "GQAFlashAttentionPrefillFusedOp", symbol_table
 ):
     """
-    Lower FlashAttentionForCpuPrefillOp into MLIR affine+vector IR.
+    Lower GQAFlashAttentionPrefillFusedOp into MLIR affine+vector IR.
     Returns:
         result_tensor: Final attention output tensor
         log_sumexp_reshape: Placeholder log-sum-exp (can be reshaped as needed)
@@ -3936,6 +3936,8 @@ def flash_attention_for_cpu_prefill_op(
             h = body_block.add_argument(
                 ir.IndexType.get(), ir.Location.unknown()
             )
+            c6i = arith.ConstantOp(index, 6, loc=loc).result
+            hk = arith.DivSIOp(h, c6i, loc=loc).result
             # query sequence length block loop
             loop_q = affine.AffineParallelOp(
                 results_=[],
@@ -4061,7 +4063,7 @@
                 v16_qkv, Q_memref, [b, h, idx_q, k]
             )
             k_data = vector.LoadOp(
-                v16_qkv, K_memref, [b, h, idx_k, k]
+                v16_qkv, K_memref, [b, hk, idx_k, k]
             )
             new_acc = vector.FMAOp(
                 q_data.result,
@@ -4165,7 +4167,7 @@
             with ir.InsertionPoint(loop_k.body):
                 k = loop_k.induction_variable
                 v_data = vector.LoadOp(
-                    v16_qkv, V_memref, [b, h, idx_k, k]
+                    v16_qkv, V_memref, [b, hk, idx_k, k]
                 )
                 acc_block_val = vector.LoadOp(
                     v16, acc_block_memref, [k]
@@ -11221,7 +11223,9 @@ def diagonal_op(node: DiagonalOp, symbol_table):
     return output_tensor


-def gqa_attention_fused_op(node: GQAAttentionFusedOp, symbol_table):
+def gqa_attention_decode_fused_op(
+    node: GQAAttentionDecodeFusedOp, symbol_table
+):
     """
     Import attention kernel (QK^T + softmax + V) from graph to MLIR.
     """
@@ -11525,7 +11529,8 @@ def gqa_attention_fused_op(node: GQAAttentionFusedOp, symbol_table):
     "RandIntLowOp": randint_low_op,
     "ArgMaxOp": argmax_op,
     "ScaledDotProductFlashAttentionForCpuOp": scaled_dot_product_flash_attention_for_cpu_op,
-    "FlashAttentionForCpuPrefillOp": flash_attention_for_cpu_prefill_op,
+    "GQAFlashAttentionPrefillFusedOp": gqa_flash_attention_prefill_fused_op,
+    "GQAAttentionDecodeFusedOp": gqa_attention_decode_fused_op,
     "LeOp": le_op,
     "BitwiseAndTensorOp": bitwise_and_tensor_op,
     "BitwiseLeftShiftOp": bitwise_left_shift_op,
@@ -11697,5 +11702,4 @@ def gqa_attention_fused_op(node: GQAAttentionFusedOp, symbol_table):
     # FftR2cOp is implemented in linalg.py using DFT matrix multiplication
     "LocalScalarDenseOp": local_scalar_dense_op,
     "ResizeOp": resize_op,
-    "GQAAttentionFusedOp": gqa_attention_fused_op,
 }
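The two lines added in the prefill lowering (c6i and hk) encode the GQA head mapping: Q_memref holds all query heads while K_memref and V_memref hold only the KV heads, so the query head index h is divided by the group size before the K/V loads. A hedged sketch of that index arithmetic in plain Python, assuming the shapes used in the tests below (12 query heads over 2 KV heads, hence the hardcoded constant 6):

    num_q_heads, num_kv_heads = 12, 2  # memref<1x12x1x128xf32> vs. memref<1x2x1024x128xf32>
    group_size = num_q_heads // num_kv_heads  # 6, the constant materialized as c6i
    for h in range(num_q_heads):  # query head index from the head loop
        hk = h // group_size  # KV head index used by the K_memref/V_memref loads
        assert 0 <= hk < num_kv_heads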
6 changes: 2 additions & 4 deletions tests/Python/test_gqa_attention.py
@@ -5,9 +5,7 @@
 import torch.nn.functional as F
 import torch._dynamo as dynamo
 from torch._inductor.decomposition import decompositions as inductor_decomp
-from buddy.compiler.graph.transform import (
-    gqa_attention_fusion,
-)
+from buddy.compiler.graph.transform import gqa_attention_decode_fusion

 from buddy.compiler.frontend import DynamoCompiler
 from buddy.compiler.ops import tosa
@@ -53,7 +51,7 @@ def foo(query, k_cache, v_cache, mask, scale):
 assert len(graphs) == 1
 graph = graphs[0]

-pattern_list = [gqa_attention_fusion]
+pattern_list = [gqa_attention_decode_fusion]
 graphs[0].fuse_ops(pattern_list)

 graph.lower_to_top_level_ir()
75 changes: 75 additions & 0 deletions tests/Python/test_gqa_flash_attention.py
@@ -0,0 +1,75 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import torch
+import torch.nn.functional as F
+import torch._dynamo as dynamo
+from torch._inductor.decomposition import decompositions as inductor_decomp
+from buddy.compiler.graph.transform import gqa_flash_attention_prefill_fusion
+
+from buddy.compiler.frontend import DynamoCompiler
+from buddy.compiler.ops import tosa
+
+
+def foo(query, k_cache, v_cache, mask, scale):
+
+    k_unsqueeze = torch.unsqueeze(k_cache, 2)
+    k_slice1 = torch.narrow(k_unsqueeze, 3, 0, k_unsqueeze.size(3))
+    k_slice2 = torch.narrow(k_slice1, 4, 0, k_slice1.size(4))
+    k_expanded = k_slice2.expand(1, 2, 6, 1024, 128)
+    k_clone = k_expanded.clone()
+    k_view = k_clone.view(1, 12, 1024, 128)
+
+    v_unsqueeze = torch.unsqueeze(v_cache, 2)
+    v_slice1 = torch.narrow(v_unsqueeze, 3, 0, v_unsqueeze.size(3))
+    v_slice2 = torch.narrow(v_slice1, 4, 0, v_slice1.size(4))
+    v_expanded = v_slice2.expand(1, 2, 6, 1024, 128)
+    v_clone = v_expanded.clone()
+    v_view = v_clone.view(1, 12, 1024, 128)
+
+    attn_output = F.scaled_dot_product_attention(
+        query, k_view, v_view, attn_mask=mask, scale=scale
+    )
+
+    return attn_output
+
+
+in1 = torch.randn(1, 12, 1, 128)  # query: [Batch, QHeads, SeqLen, Dim]
+in2 = torch.randn(1, 2, 1024, 128)  # k_cache: [Batch, KVHeads, MaxSeq, Dim]
+in3 = torch.randn(1, 2, 1024, 128)  # v_cache: [Batch, KVHeads, MaxSeq, Dim]
+in4 = torch.randn(1, 1, 1, 1024)  # attention mask
+in5 = 1.0 / (128**0.5)  # scale = 1 / sqrt(head_dim)
+
+# Initialize the dynamo compiler.
+dynamo_compiler = DynamoCompiler(
+    primary_registry=tosa.ops_registry,
+    aot_autograd_decomposition=inductor_decomp,
+    verbose=False,
+)
+
+graphs = dynamo_compiler.importer(foo, in1, in2, in3, in4, in5)
+assert len(graphs) == 1
+graph = graphs[0]
+
+pattern_list = [gqa_flash_attention_prefill_fusion]
+graphs[0].fuse_ops(pattern_list)
+
+graph.lower_to_top_level_ir()
+print(graph._imported_module)
+
+# CHECK-LABEL: func.func @forward
+# CHECK: %[[C6:.*]] = arith.constant 6 : index
+# CHECK: %[[HEAD_IDX:.*]] = arith.divsi %{{.*}}, %[[C6]] : index
+# CHECK: scf.for
+# CHECK: scf.for
+# CHECK: vector.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<1x12x1x128xf32>, vector<16xf32>
+# CHECK: vector.load %{{.*}}[%{{.*}}, %[[HEAD_IDX]], %{{.*}}, %{{.*}}] : memref<1x2x1024x128xf32>, vector<16xf32>
+# CHECK: vector.fma
+# CHECK: vector.reduction <add>
+# CHECK: arith.mulf
+# CHECK: arith.addf
+# CHECK: arith.cmpf ogt
+# CHECK: arith.select
+# CHECK: math.exp
+# CHECK: vector.fma
+# CHECK: arith.divf
+# CHECK: return %{{.*}} : tensor<1x12x1x128xf32>
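For reference (not part of the committed test), the unsqueeze/narrow/expand/clone/view chain in foo replicates each KV head contiguously across its group of six query heads, which is exactly the mapping the fused kernel recovers with h // 6. A hedged equivalence check, reusing the inputs defined above:

    k_rep = in2.repeat_interleave(6, dim=1)  # (1, 2, 1024, 128) -> (1, 12, 1024, 128)
    v_rep = in3.repeat_interleave(6, dim=1)
    ref = F.scaled_dot_product_attention(in1, k_rep, v_rep, attn_mask=in4, scale=in5)
    torch.testing.assert_close(foo(in1, in2, in3, in4, in5), ref)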