diff --git a/.gitignore b/.gitignore
index c2e66af8..44426189 100755
--- a/.gitignore
+++ b/.gitignore
@@ -20,3 +20,4 @@ id_ed25519.pub
 *.model
 .cline_storage
 *.egg-info
+CLAUDE.md
diff --git a/aie_kernels/aie2/dual_gemv_silu_mul.cc b/aie_kernels/aie2/dual_gemv_silu_mul.cc
new file mode 100644
index 00000000..e28aaae1
--- /dev/null
+++ b/aie_kernels/aie2/dual_gemv_silu_mul.cc
@@ -0,0 +1,79 @@
+// SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+// Fused dual-GEMV + SiLU + elementwise multiply kernel for AIE2.
+// Same structure as AIE2+ variant but uses LUT-based getTanhBf16.
+
+#define NOCPP
+
+#include "../aie_kernel_utils.h"
+#include "lut_based_ops.h"
+
+// NOTE(review): the angle-bracket include targets were stripped during text
+// extraction; the set below is the conventional one for these kernels --
+// confirm against the original commit.
+#include <aie_api/aie.hpp>
+#include <stdint.h>
+#include <stdio.h>
+
+// Kernel-internal scratch: phase 0 (W1 @ x) lands in left_buf, phase 1
+// (W2 @ x) in right_buf. 64-byte alignment keeps aie::load_v legal.
+static bfloat16 left_buf[1024] __attribute__((aligned(64)));
+static bfloat16 right_buf[1024] __attribute__((aligned(64)));
+
+// Row-major GEMV: c[i] = dot(a_row_i, b), vectorized r lanes at a time.
+// NOTE(review): the template parameter list was stripped in extraction;
+// reconstructed from the uses of `r` in the body (b_cur += r, load_v<r>).
+template <unsigned r>
+void matvec_vectorized(uint32_t m,
+                       uint32_t k,
+                       const bfloat16 *__restrict a,
+                       const bfloat16 *__restrict b,
+                       bfloat16 *__restrict c)
+{
+    ::aie::set_rounding(aie::rounding_mode::conv_even);
+    bfloat16 *c_end = c + m;
+    const bfloat16 *b_end = b + k;
+    for (; c < c_end; c++) {
+        aie::accum<accfloat, r> acc = aie::zeros<accfloat, r>();
+        AIE_LOOP_MIN_ITERATION_COUNT(2)
+        for (const bfloat16 *__restrict b_cur = b; b_cur < b_end; b_cur += r, a += r) {
+            aie::vector<bfloat16, r> a_vec = aie::load_v<r>(a);
+            aie::vector<bfloat16, r> b_vec = aie::load_v<r>(b_cur);
+            acc = aie::mac(acc, a_vec, b_vec);
+        }
+        *c = static_cast<bfloat16>(aie::reduce_add(acc.template to_vector<float>()));
+    }
+}
+
+extern "C" {
+
+// GEMV into the static buffers: phase=0 writes left_buf, phase=1 writes
+// right_buf; row_offset selects the destination slice for this row tile.
+void dual_gemv_matvec_bf16(uint32_t m,
+                           uint32_t k,
+                           uint32_t row_offset,
+                           const bfloat16 *__restrict a_in,
+                           const bfloat16 *__restrict b_in,
+                           uint32_t phase)
+{
+    bfloat16 *dst = (phase == 0) ? left_buf : right_buf;
+    dst += row_offset;
+    matvec_vectorized<64>(m, k, a_in, b_in, dst);
+}
+
+// silu(left_buf) * right_buf -> c_out, 16 bf16 lanes per step.
+// SiLU(x) = x * sigmoid(x) = x * 0.5 * (1 + tanh(x/2)); tanh via LUT.
+void dual_gemv_silu_mul_bf16(bfloat16 *__restrict c_out, int32_t m_output)
+{
+    event0();
+
+    aie::vector<bfloat16, 16> register_0_5 = aie::broadcast<bfloat16, 16>(0.5f);
+    aie::vector<bfloat16, 16> register_1 = aie::broadcast<bfloat16, 16>(1.0f);
+    AIE_PREPARE_FOR_PIPELINING
+    for (int i = 0; i < m_output; i += 16) {
+        aie::vector<bfloat16, 16> left_val = aie::load_v<16>(left_buf + i);
+        aie::vector<bfloat16, 16> right_val = aie::load_v<16>(right_buf + i);
+
+        // NOTE(review): element types below reconstructed as <bfloat16, 16>;
+        // the original may have converted aie::mul accum results with
+        // .to_vector<bfloat16>() -- confirm against the original commit.
+        aie::vector<bfloat16, 16> half_x = aie::mul(left_val, register_0_5);
+        aie::vector<bfloat16, 16> tanh_half_x = getTanhBf16(half_x);
+        auto tanh_half_x_approx = aie::add(tanh_half_x, register_1);
+        aie::vector<bfloat16, 16> sigmoid_approx = aie::mul(tanh_half_x_approx, register_0_5);
+        auto silu_output = aie::mul(left_val, sigmoid_approx);
+
+        auto fused_output = aie::mul(silu_output.to_vector<bfloat16>(), right_val);
+        aie::store_v(c_out + i, fused_output.to_vector<bfloat16>());
+    }
+
+    event1();
+}
+
+} // extern "C"
diff --git a/aie_kernels/aie2/silu_mul.cc b/aie_kernels/aie2/silu_mul.cc
new file mode 100644
index 00000000..a5a15eb8
--- /dev/null
+++ b/aie_kernels/aie2/silu_mul.cc
@@ -0,0 +1,61 @@
+// SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+#include "../aie_kernel_utils.h"
+#include "lut_based_ops.h"
+
+// NOTE(review): angle-bracket include targets were stripped in extraction;
+// reconstructed -- confirm against the original commit.
+#include <aie_api/aie.hpp>
+#include <stdint.h>
+
+using namespace aie;
+
+// Fused SiLU + elementwise multiply over bf16 data, 16 lanes per step:
+//   output_vector[i] = silu(silu_input[i]) * mul_input[i]
+// vector_size is consumed in steps of 16 (assumes a multiple of 16).
+// tanh is the LUT-based getTanhBf16 (AIE2 variant).
+void silu_mul_tanh_approx_bf16(bfloat16 *restrict silu_input,
+                               bfloat16 *restrict mul_input,
+                               bfloat16 *restrict output_vector,
+                               const int32_t vector_size)
+{
+    event0();
+
+    auto it_silu_in = aie::begin_restrict_vector<16>((bfloat16 *)silu_input);
+    auto it_mul_in = aie::begin_restrict_vector<16>((bfloat16 *)mul_input);
+    auto it_out = aie::begin_restrict_vector<16>((bfloat16 *)output_vector);
+
+    aie::vector<bfloat16, 16> register_0_5 = aie::broadcast<bfloat16, 16>(0.5f);
+    aie::vector<bfloat16, 16> register_1 = aie::broadcast<bfloat16, 16>(1.0f);
+    AIE_PREPARE_FOR_PIPELINING
+    AIE_LOOP_MIN_ITERATION_COUNT(64)
+    for (int i = 0; i < vector_size; i += 16) {
+        // Load input vectors
+        aie::vector<bfloat16, 16> input = *it_silu_in++;
+        aie::vector<bfloat16, 16> mul_in = *it_mul_in++;
+
+        // Compute SiLU: x * sigmoid(x) where sigmoid(x) = 0.5 * (1 + tanh(x/2))
+        // NOTE(review): element types reconstructed as <bfloat16, 16>; the
+        // original may have converted accum results via .to_vector<bfloat16>().
+        aie::vector<bfloat16, 16> half_x = aie::mul(input, register_0_5);
+        aie::vector<bfloat16, 16> tanh_half_x = getTanhBf16(half_x);
+        auto tanh_half_x_approx = aie::add(tanh_half_x, register_1);
+        aie::vector<bfloat16, 16> sigmoid_approx = aie::mul(tanh_half_x_approx, register_0_5);
+        auto silu_output = aie::mul(input, sigmoid_approx);
+
+        // Fused multiply: silu(input) * mul_input
+        auto fused_output = aie::mul(silu_output.to_vector<bfloat16>(), mul_in);
+
+        // Store output vector
+        *it_out++ = fused_output.to_vector<bfloat16>();
+    }
+
+    event1();
+
+    return;
+}
+
+extern "C" {
+
+// C ABI wrapper invoked from the NPU design.
+void silu_mul_bf16(bfloat16 *restrict silu_input,
+                   bfloat16 *restrict mul_input,
+                   bfloat16 *restrict output,
+                   int input_size)
+{
+    silu_mul_tanh_approx_bf16(silu_input, mul_input, output, input_size);
+}
+
+} // extern "C"
diff --git a/aie_kernels/aie2p/dual_gemv_silu_mul.cc b/aie_kernels/aie2p/dual_gemv_silu_mul.cc
new file mode 100644
index 00000000..178364af
--- /dev/null
+++ b/aie_kernels/aie2p/dual_gemv_silu_mul.cc
@@ -0,0 +1,90 @@
+// SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc.
All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+// Fused dual-GEMV + SiLU + elementwise multiply kernel for AIE2+.
+//
+// Computes: output = silu(W1 @ x) * (W2 @ x)
+//
+// Two entry points called from the NPU design's core body:
+// 1. dual_gemv_matvec_bf16: GEMV writing to FIFO buffer c_out + row_offset
+// 2. dual_gemv_silu_mul_bf16: reads from static left_buf/right_buf, writes to FIFO c_out
+//
+// The static buffers are written via scalar stores (from matvec) and read
+// via aie::load_v in the silu_mul phase. Aligned to 64 bytes for safe vector access.
+
+#define NOCPP
+
+#include "../aie_kernel_utils.h"
+
+// NOTE(review): angle-bracket include targets were stripped in extraction;
+// reconstructed -- confirm against the original commit.
+#include <aie_api/aie.hpp>
+#include <stdint.h>
+#include <stdio.h>
+
+static bfloat16 left_buf[1024] __attribute__((aligned(64)));
+static bfloat16 right_buf[1024] __attribute__((aligned(64)));
+
+// Row-major GEMV: c[i] = dot(a_row_i, b), vectorized r lanes at a time.
+// NOTE(review): template parameter list reconstructed from the uses of `r`.
+template <unsigned r>
+void matvec_vectorized(uint32_t m,
+                       uint32_t k,
+                       const bfloat16 *__restrict a,
+                       const bfloat16 *__restrict b,
+                       bfloat16 *__restrict c)
+{
+    ::aie::set_rounding(aie::rounding_mode::conv_even);
+    bfloat16 *c_end = c + m;
+    const bfloat16 *b_end = b + k;
+    for (; c < c_end; c++) {
+        aie::accum<accfloat, r> acc = aie::zeros<accfloat, r>();
+        AIE_LOOP_MIN_ITERATION_COUNT(2)
+        for (const bfloat16 *__restrict b_cur = b; b_cur < b_end; b_cur += r, a += r) {
+            aie::vector<bfloat16, r> a_vec = aie::load_v<r>(a);
+            aie::vector<bfloat16, r> b_vec = aie::load_v<r>(b_cur);
+            acc = aie::mac(acc, a_vec, b_vec);
+        }
+        *c = static_cast<bfloat16>(aie::reduce_add(acc.template to_vector<float>()));
+    }
+}
+
+extern "C" {
+
+// Phase 1 & 2: GEMV writing to a static buffer (left_buf or right_buf)
+// phase=0 writes to left_buf, phase=1 writes to right_buf
+void dual_gemv_matvec_bf16(uint32_t m,
+                           uint32_t k,
+                           uint32_t row_offset,
+                           const bfloat16 *__restrict a_in,
+                           const bfloat16 *__restrict b_in,
+                           uint32_t phase)
+{
+    bfloat16 *dst = (phase == 0) ? left_buf : right_buf;
+    dst += row_offset;
+    matvec_vectorized<64>(m, k, a_in, b_in, dst);
+}
+
+// Phase 3: silu(left_buf) * right_buf -> c_out (FIFO buffer)
+void dual_gemv_silu_mul_bf16(bfloat16 *__restrict c_out, int32_t m_output)
+{
+    event0();
+
+    aie::vector<bfloat16, 16> register_0_5 = aie::broadcast<bfloat16, 16>(0.5f);
+    aie::vector<bfloat16, 16> register_1 = aie::broadcast<bfloat16, 16>(1.0f);
+    AIE_PREPARE_FOR_PIPELINING
+    for (int i = 0; i < m_output; i += 16) {
+        aie::vector<bfloat16, 16> left_val = aie::load_v<16>(left_buf + i);
+        aie::vector<bfloat16, 16> right_val = aie::load_v<16>(right_buf + i);
+
+        // SiLU(x) = x * sigmoid(x) = x * 0.5 * (1 + tanh(x/2))
+        // NOTE(review): template args below reconstructed as <bfloat16, 16>.
+        auto half_x = aie::mul(left_val, register_0_5);
+        auto tanh_half_x = aie::tanh(half_x.to_vector<bfloat16>());
+        auto tanh_half_x_approx = aie::add(tanh_half_x, register_1);
+        aie::vector<bfloat16, 16> sigmoid_approx = aie::mul(tanh_half_x_approx, register_0_5);
+        auto silu_output = aie::mul(left_val, sigmoid_approx);
+
+        auto fused_output = aie::mul(silu_output.to_vector<bfloat16>(), right_val);
+        aie::store_v(c_out + i, fused_output.to_vector<bfloat16>());
+    }
+
+    event1();
+}
+
+} // extern "C"
diff --git a/aie_kernels/aie2p/silu_mul.cc b/aie_kernels/aie2p/silu_mul.cc
new file mode 100644
index 00000000..51fd05a0
--- /dev/null
+++ b/aie_kernels/aie2p/silu_mul.cc
@@ -0,0 +1,60 @@
+// SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved.
+// SPDX-License-Identifier: Apache-2.0
+
+#include "../aie_kernel_utils.h"
+
+// NOTE(review): angle-bracket include targets were stripped in extraction;
+// reconstructed -- confirm against the original commit.
+#include <aie_api/aie.hpp>
+#include <stdint.h>
+
+using namespace aie;
+
+// Fused SiLU + elementwise multiply over bf16 data, 16 lanes per step:
+//   output_vector[i] = silu(silu_input[i]) * mul_input[i]
+// vector_size is consumed in steps of 16 (assumes a multiple of 16).
+// AIE2+ variant: uses the hardware aie::tanh rather than a LUT.
+void silu_mul_tanh_approx_bf16(bfloat16 *restrict silu_input,
+                               bfloat16 *restrict mul_input,
+                               bfloat16 *restrict output_vector,
+                               const int32_t vector_size)
+{
+    event0();
+
+    auto it_silu_in = aie::begin_restrict_vector<16>((bfloat16 *)silu_input);
+    auto it_mul_in = aie::begin_restrict_vector<16>((bfloat16 *)mul_input);
+    auto it_out = aie::begin_restrict_vector<16>((bfloat16 *)output_vector);
+
+    aie::vector<bfloat16, 16> register_0_5 = aie::broadcast<bfloat16, 16>(0.5f);
+    aie::vector<bfloat16, 16> register_1 = aie::broadcast<bfloat16, 16>(1.0f);
+    AIE_PREPARE_FOR_PIPELINING
+    AIE_LOOP_MIN_ITERATION_COUNT(64)
+    for (int i = 0; i < vector_size; i += 16) {
+        // Load input vectors
+        aie::vector<bfloat16, 16> input = *it_silu_in++;
+        aie::vector<bfloat16, 16> mul_in = *it_mul_in++;
+
+        // Compute SiLU: x * sigmoid(x) where sigmoid(x) = 0.5 * (1 + tanh(x/2))
+        // NOTE(review): template args below reconstructed as <bfloat16, 16>.
+        auto half_x = aie::mul(input, register_0_5);
+        auto tanh_half_x = aie::tanh(half_x.to_vector<bfloat16>());
+        auto tanh_half_x_approx = aie::add(tanh_half_x, register_1);
+        aie::vector<bfloat16, 16> sigmoid_approx = aie::mul(tanh_half_x_approx, register_0_5);
+        auto silu_output = aie::mul(input, sigmoid_approx);
+
+        // Fused multiply: silu(input) * mul_input
+        auto fused_output = aie::mul(silu_output.to_vector<bfloat16>(), mul_in);
+
+        // Store output vector
+        *it_out++ = fused_output.to_vector<bfloat16>();
+    }
+
+    event1();
+
+    return;
+}
+
+extern "C" {
+
+// C ABI wrapper invoked from the NPU design.
+void silu_mul_bf16(bfloat16 *restrict silu_input,
+                   bfloat16 *restrict mul_input,
+                   bfloat16 *restrict output,
+                   int input_size)
+{
+    silu_mul_tanh_approx_bf16(silu_input, mul_input, output, input_size);
+}
+
+} // extern "C"
diff --git a/iron/operators/__init__.py b/iron/operators/__init__.py
index fc203892..98cf0a1e 100644
--- a/iron/operators/__init__.py
+++ b/iron/operators/__init__.py
@@ -3,6 +3,7 @@
 from .axpy.op import AIEAXPY
 from .dequant.op import AIEDequant
+from .dual_gemv_silu_mul.op import AIEDualGEMVSiLUMul
 from .elementwise_add.op import
AIEElementwiseAdd from .elementwise_mul.op import AIEElementwiseMul from .gelu.op import AIEGELU @@ -17,6 +18,7 @@ from .rope.op import AIERope from .sigmoid.op import AIESigmoid from .silu.op import AIESiLU +from .silu_mul.op import AIESiLUMul from .softmax.op import AIESoftmax from .swiglu_decode.op import AIESwiGLUDecode from .swiglu_prefill.op import AIESwiGLUPrefill diff --git a/iron/operators/dual_gemv_silu_mul/design.py b/iron/operators/dual_gemv_silu_mul/design.py new file mode 100644 index 00000000..b2f5de4a --- /dev/null +++ b/iron/operators/dual_gemv_silu_mul/design.py @@ -0,0 +1,180 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import numpy as np +from pathlib import Path +from ml_dtypes import bfloat16 +import argparse + +import aie.dialects.index as index +from aie.dialects.aie import * +from aie.dialects.aiex import * +from aie.helpers.dialects.scf import _for as range_ +from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker +from aie.iron.placers import SequentialPlacer +from aie.iron.device import NPU1, NPU2 + +""" +Dual matrix-vector + SiLU + elementwise multiply design. + +Computes: output = silu(W1 @ x) * (W2 @ x) + +W1 and W2 rows are pre-interleaved in DDR by the operator (op.py). +GEMV phases write to kernel-internal static buffers (left_buf, right_buf) +controlled by a phase parameter. The silu_mul phase reads from those +buffers and writes the result to the output C FIFO. + +Each AIE core: + 1. Acquires vector x (held in L1 for both GEMV passes) + 2. Consumes W1 rows from A FIFO, writes dot products to left_buf (phase=0) + 3. Consumes W2 rows from A FIFO, writes dot products to right_buf (phase=1) + 4. 
Computes silu(left_buf) * right_buf -> C FIFO output +""" + + +def my_dual_gemv_silu_mul(dev, cols, M, K, m_input, m_output=None): + if m_output is None: + m_output = m_input + + assert m_output % m_input == 0 and m_output >= m_input + assert m_output <= M // cols + assert (M // cols) % m_output == 0 + assert m_input <= M // cols + assert (M // cols) % m_input == 0 + + dtype_in = np.dtype[bfloat16] + dtype_out = np.dtype[bfloat16] + + assert M % cols == 0 + + dev_ty = NPU1() if dev == "npu" else NPU2() + + # L1 tile types + L1_A_ty = np.ndarray[(m_input, K), dtype_in] + L1_B_ty = np.ndarray[(K,), dtype_in] + L1_C_ty = np.ndarray[(m_output,), dtype_out] + + # L3 (DDR) buffer types + L3_W_ty = np.ndarray[(2 * M, K), dtype_in] + L3_B_ty = np.ndarray[(K,), dtype_in] + L3_C_ty = np.ndarray[(M,), dtype_out] + + # GEMV: writes to left_buf (phase=0) or right_buf (phase=1) + matvec = Kernel( + "dual_gemv_matvec_bf16", + "dual_gemv_silu_mul.o", + [np.int32, np.int32, np.int32, L1_A_ty, L1_B_ty, np.int32], + ) + + # SiLU+Mul: reads from static left_buf/right_buf, writes to C FIFO + silu_mul_fn = Kernel( + "dual_gemv_silu_mul_bf16", + "dual_gemv_silu_mul.o", + [L1_C_ty, np.int32], + ) + + # ObjectFIFOs: 2 inputs + 1 output = fits AIE DMA channel limits + A_fifos = [ObjectFifo(L1_A_ty, name=f"A_{i}", depth=2) for i in range(cols)] + B_fifos = [ObjectFifo(L1_B_ty, name=f"B_{i}", depth=1) for i in range(cols)] + C_fifos = [ObjectFifo(L1_C_ty, name=f"C_{i}", depth=2) for i in range(cols)] + + def core_body(A_fifo, B_fifo, C_fifo, matvec_fn, silu_mul): + for _ in range_(0xFFFFFFFF): + b = B_fifo.acquire(1) + for i_idx in range_(M // m_output // cols): + # Phase 1: W1 rows -> left_buf (phase=0) + for j_idx in range_(m_output // m_input): + j_i32 = index.casts(T.i32(), j_idx) + row_offset = j_i32 * m_input + a = A_fifo.acquire(1) + matvec_fn(m_input, K, row_offset, a, b, 0) + A_fifo.release(1) + # Phase 2: W2 rows -> right_buf (phase=1) + for j_idx in range_(m_output // m_input): + 
j_i32 = index.casts(T.i32(), j_idx) + row_offset = j_i32 * m_input + a = A_fifo.acquire(1) + matvec_fn(m_input, K, row_offset, a, b, 1) + A_fifo.release(1) + # Phase 3: silu(left_buf) * right_buf -> output + c = C_fifo.acquire(1) + silu_mul(c, m_output) + C_fifo.release(1) + B_fifo.release(1) + + workers = [ + Worker( + core_body, + [ + A_fifos[i].cons(), + B_fifos[i].cons(), + C_fifos[i].prod(), + matvec, + silu_mul_fn, + ], + ) + for i in range(cols) + ] + + # Interleaved weight distribution per column + rows_per_col = M // cols + A_taps = [ + TensorAccessPattern( + tensor_dims=(2 * M, K), + offset=col * 2 * rows_per_col * K, + sizes=[1, 1, 1, 2 * rows_per_col * K], + strides=[0, 0, 0, 1], + ) + for col in range(cols) + ] + + # Output collection + C_taps = [ + TensorAccessPattern( + tensor_dims=(1, M), + offset=col * (M // cols), + sizes=[1, 1, 1, (M // cols)], + strides=[0, 0, 0, 1], + ) + for col in range(cols) + ] + + rt = Runtime() + with rt.sequence(L3_W_ty, L3_B_ty, L3_C_ty) as (W, B, C): + rt.start(*workers) + tg = rt.task_group() + for i in range(cols): + rt.fill(A_fifos[i].prod(), W, A_taps[i], task_group=tg) + rt.fill(B_fifos[i].prod(), B, task_group=tg) + for i in range(cols): + rt.drain(C_fifos[i].cons(), C, C_taps[i], task_group=tg, wait=True) + rt.finish_task_group(tg) + + return Program(dev_ty, rt).resolve_program(SequentialPlacer()) + + +if __name__ == "__main__": + argparser = argparse.ArgumentParser( + prog="AIE Dual GEMV + SiLU + Mul Design", + ) + argparser.add_argument("--dev", type=str, choices=["npu", "npu2"], default="npu") + argparser.add_argument("-M", type=int, required=True) + argparser.add_argument("-K", type=int, required=True) + argparser.add_argument("-m", type=int, required=True, dest="m_input") + argparser.add_argument("--m-output", type=int, default=None, dest="m_output") + argparser.add_argument("--cols", type=int, required=True) + argparser.add_argument( + "--output-file-path", + "-o", + type=str, + help="Output file path for 
the generated MLIR module", + ) + args = argparser.parse_args() + module = my_dual_gemv_silu_mul( + args.dev, args.cols, args.M, args.K, args.m_input, args.m_output + ) + + output_file_path = Path(args.output_file_path) + + with open(output_file_path, "w") as f: + f.write(str(module)) diff --git a/iron/operators/dual_gemv_silu_mul/op.py b/iron/operators/dual_gemv_silu_mul/op.py new file mode 100644 index 00000000..64834181 --- /dev/null +++ b/iron/operators/dual_gemv_silu_mul/op.py @@ -0,0 +1,159 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import numpy as np +from ml_dtypes import bfloat16 +from pathlib import Path + +from iron.common import ( + AIEOperatorBase, + AIEOperatorConstraintError, + XclbinArtifact, + InstsBinArtifact, + KernelObjectArtifact, + SourceArtifact, + PythonGeneratedMLIRArtifact, +) +from iron.common.utils import torch_to_numpy + + +def interleave_weights(W1, W2, rows_per_col, cols): + """Interleave W1 and W2 rows per-column for the fused DMA pattern. + + Output layout: [W1_col0_rows, W2_col0_rows, W1_col1_rows, W2_col1_rows, ...] + + This ensures that when the DMA streams data to each column's A FIFO, + the W1 rows arrive first followed by W2 rows, matching the core body's + consumption order. + """ + M = W1.shape[0] + K = W1.shape[1] + result = torch.empty(2 * M, K, dtype=W1.dtype) + for col in range(cols): + start = col * rows_per_col + end = start + rows_per_col + out_start = col * 2 * rows_per_col + result[out_start : out_start + rows_per_col] = W1[start:end] + result[out_start + rows_per_col : out_start + 2 * rows_per_col] = W2[start:end] + return result + + +class AIEDualGEMVSiLUMul(AIEOperatorBase): + """AIE-accelerated fused dual-GEMV + SiLU + elementwise multiply. 
+ + Computes: output = silu(W1 @ x) * (W2 @ x) + + Fuses three operations into a single NPU design: + - Two matrix-vector multiplications sharing the same input vector + - SiLU activation on the first GEMV result + - Elementwise multiplication of SiLU output with second GEMV result + + The intermediate vectors (left, right) never touch DRAM. + W1 and W2 are pre-interleaved in DDR for DMA-compatible streaming. + """ + + def __init__( + self, + M, + K, + num_aie_columns=4, + tile_size_input=4, + tile_size_output=None, + context=None, + ): + if tile_size_output is None: + tile_size_output = M // num_aie_columns + assert tile_size_output % tile_size_input == 0 + assert tile_size_output >= tile_size_input + self.M = M + self.K = K + self.num_aie_columns = num_aie_columns + self.tile_size_input = tile_size_input + self.tile_size_output = tile_size_output + + self.xclbin_artifact = None + self.insts_artifact = None + + AIEOperatorBase.__init__(self, context=context) + + def get_artifacts(self, prefix="dual_gemv_silu_mul_"): + operator_dir = Path(__file__).parent + file_name_base = ( + f"{prefix}{self.M}x{self.K}_{self.tile_size_input}tsi_" + f"{self.tile_size_output}tso_{self.num_aie_columns}col" + ) + + mlir_artifact = PythonGeneratedMLIRArtifact.new( + f"{file_name_base}.mlir", + import_path=operator_dir / "design.py", + callback_fn="my_dual_gemv_silu_mul", + callback_args=[ + self.context.device_manager.device_type, + self.num_aie_columns, + self.M, + self.K, + self.tile_size_input, + self.tile_size_output, + ], + ) + + xclbin_artifact = XclbinArtifact.new( + f"{file_name_base}.xclbin", + depends=[ + mlir_artifact, + KernelObjectArtifact.new( + "dual_gemv_silu_mul.o", + depends=[ + SourceArtifact.new( + self.context.base_dir + / "aie_kernels" + / "aie2p" + / "dual_gemv_silu_mul.cc" + ) + ], + ), + ], + ) + + insts_artifact = InstsBinArtifact.new( + f"{file_name_base}.bin", depends=[mlir_artifact] + ) + + return xclbin_artifact, insts_artifact + + def 
set_up_artifacts(self): + xclbin_artifact, insts_artifact = self.get_artifacts() + self.xclbin_artifact = xclbin_artifact + self.insts_artifact = insts_artifact + self.add_artifacts([xclbin_artifact, insts_artifact]) + + def set_up_runtime(self): + # The design expects a single interleaved weight buffer (2*M*K) + self.add_buffer("weights_interleaved", 2 * self.M * self.K) + self.add_buffer("vector", self.K) + self.add_buffer("output", self.M) + self.add_kernel( + "dual_gemv_silu_mul", + self.xclbin_artifact, + self.xclbin_artifact.kernel_name, + self.insts_artifact, + ) + self.add_to_runlist( + "dual_gemv_silu_mul", "weights_interleaved", "vector", "output" + ) + + def forward(self, vector, matrix1=None, matrix2=None): + """Forward pass: computes silu(matrix1 @ vector) * (matrix2 @ vector)""" + vector = vector.reshape(*vector.shape[-1:]) + + if matrix1 is not None and matrix2 is not None: + rows_per_col = self.M // self.num_aie_columns + w_interleaved = interleave_weights( + matrix1, matrix2, rows_per_col, self.num_aie_columns + ) + self.write_buffer("weights_interleaved", w_interleaved) + self.write_buffer("vector", vector) + self.run_runlist() + result = self.read_buffer_as_torch("output", (self.M,)) + return result diff --git a/iron/operators/dual_gemv_silu_mul/reference.py b/iron/operators/dual_gemv_silu_mul/reference.py new file mode 100644 index 00000000..b50be353 --- /dev/null +++ b/iron/operators/dual_gemv_silu_mul/reference.py @@ -0,0 +1,24 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch + + +def generate_golden_reference(M=2048, K=2048, seed=42): + """Generate golden reference for dual-GEMV + SiLU + Mul. + + Computes: output = silu(W1 @ x) * (W2 @ x) + + Returns dict with W1, W2, x, and output tensors. 
+ """ + torch.manual_seed(seed) + val_range = 4 + W1 = torch.randn(M, K, dtype=torch.bfloat16) * val_range + W2 = torch.randn(M, K, dtype=torch.bfloat16) * val_range + x = torch.randn(K, dtype=torch.bfloat16) * val_range + + left = W1 @ x + right = W2 @ x + output = torch.nn.functional.silu(left) * right + + return {"W1": W1, "W2": W2, "x": x, "output": output} diff --git a/iron/operators/dual_gemv_silu_mul/test.py b/iron/operators/dual_gemv_silu_mul/test.py new file mode 100644 index 00000000..f5b8e4f2 --- /dev/null +++ b/iron/operators/dual_gemv_silu_mul/test.py @@ -0,0 +1,80 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import pytest + +from iron.operators.dual_gemv_silu_mul.op import AIEDualGEMVSiLUMul, interleave_weights +from iron.operators.dual_gemv_silu_mul.reference import generate_golden_reference +from iron.common.test_utils import run_test + + +def generate_test_params(extensive=False): + params = [ + # (M, K, num_aie_columns, tile_size_input, tile_size_output) + (2048, 2048, 4, 4, 512), + ] + if extensive: + params += [ + (8192, 2048, 4, 4, 2048), + ] + names = [ + f"dual_gemv_silu_mul_{M}x{K}_{tsi}tsi_{tso}tso_{cols}col" + for M, K, cols, tsi, tso in params + ] + return params, names + + +regular_params, regular_names = generate_test_params(extensive=False) +extensive_params, extensive_names = generate_test_params(extensive=True) + +all_params = [ + pytest.param(*params, id=name) + for params, name in zip(regular_params, regular_names) +] + [ + pytest.param(*params, marks=pytest.mark.extensive, id=name) + for params, name in zip(extensive_params, extensive_names) +] + + +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", +) +@pytest.mark.parametrize( + "M,K,num_aie_columns,tile_size_input,tile_size_output", + all_params, +) +def test_dual_gemv_silu_mul( + M, K, 
num_aie_columns, tile_size_input, tile_size_output, aie_context +): + golden_ref = generate_golden_reference(M=M, K=K) + + operator = AIEDualGEMVSiLUMul( + M=M, + K=K, + num_aie_columns=num_aie_columns, + tile_size_input=tile_size_input, + tile_size_output=tile_size_output, + context=aie_context, + ) + + rows_per_col = M // num_aie_columns + w_interleaved = interleave_weights( + golden_ref["W1"], golden_ref["W2"], rows_per_col, num_aie_columns + ) + + input_buffers = { + "weights_interleaved": w_interleaved.flatten(), + "vector": golden_ref["x"], + } + output_buffers = {"output": golden_ref["output"]} + + errors, latency_us, bandwidth_gbps = run_test( + operator, input_buffers, output_buffers, rel_tol=0.07, abs_tol=1.0 + ) + + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n") + + assert not errors, f"Test failed with errors: {errors}" diff --git a/iron/operators/silu_mul/design.py b/iron/operators/silu_mul/design.py new file mode 100644 index 00000000..3f3244a6 --- /dev/null +++ b/iron/operators/silu_mul/design.py @@ -0,0 +1,190 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +from ml_dtypes import bfloat16 +from pathlib import Path +import numpy as np +import argparse +import sys + +from aie.iron import Kernel, ObjectFifo, Program, Runtime, Worker +from aie.iron.placers import SequentialPlacer +from aie.iron.device import NPU1, NPU2 +from aie.helpers.taplib.tap import TensorAccessPattern +from aie.iron.controlflow import range_ + + +def my_silu_mul(dev, num_elements, num_columns, num_channels, tile_size, trace_size): + per_tile_elements = 4096 if tile_size > 4096 else tile_size + n = per_tile_elements * num_columns + if num_elements % n != 0: + raise ValueError( + f"Number of elements ({num_elements}) must be a multiple of {n}." 
+ ) + N_div_n = num_elements // n + chunk = num_elements // num_columns + dtype = bfloat16 + + # Define tensor types + tensor_ty = np.ndarray[(num_elements,), np.dtype[dtype]] + tile_ty = np.ndarray[(per_tile_elements,), np.dtype[dtype]] + + # AIE-array data movement with object fifos (one per column) + of_in1s = [ObjectFifo(tile_ty, name=f"in1_{i}") for i in range(num_columns)] + of_in2s = [ObjectFifo(tile_ty, name=f"in2_{i}") for i in range(num_columns)] + of_outs = [ObjectFifo(tile_ty, name=f"out_{i}") for i in range(num_columns)] + + # AIE Core Function declaration + silu_mul_bf16 = Kernel( + "silu_mul_bf16", "silu_mul.o", [tile_ty, tile_ty, tile_ty, np.int32] + ) + + # Define a task that will run on a compute tile + def core_body(of_in1, of_in2, of_out, silu_mul_fn): + for _ in range_(N_div_n): + elem_in1 = of_in1.acquire(1) + elem_in2 = of_in2.acquire(1) + elem_out = of_out.acquire(1) + silu_mul_fn(elem_in1, elem_in2, elem_out, per_tile_elements) + of_in1.release(1) + of_in2.release(1) + of_out.release(1) + + # Create a worker to run the task on a compute tile (one per column) + my_workers = [ + Worker( + core_body, + [ + of_in1s[i].cons(), + of_in2s[i].cons(), + of_outs[i].prod(), + silu_mul_bf16, + ], + ) + for i in range(num_columns) + ] + + # Create a TensorAccessPattern for each column + taps = [ + TensorAccessPattern( + (1, num_elements), + chunk * i, + [1, 1, 1, chunk], + [0, 0, 0, 1], + ) + for i in range(num_columns) + ] + + # Runtime operations to move data to/from the AIE-array + rt = Runtime() + with rt.sequence(tensor_ty, tensor_ty, tensor_ty) as (A, B, C): + rt.start(*my_workers) + + # Initialize a group for parallel drain tasks, with fill resources free'd when drains complete. 
+ tg = rt.task_group() + + # Fill the input objectFIFOs with data + for i in range(num_columns): + rt.fill( + of_in1s[i].prod(), + A, + taps[i], + task_group=tg, + ) + rt.fill( + of_in2s[i].prod(), + B, + taps[i], + task_group=tg, + ) + # Drain the output objectFIFOs with data + for i in range(num_columns): + rt.drain( + of_outs[i].cons(), + C, + taps[i], + wait=True, + task_group=tg, + ) + rt.finish_task_group(tg) + + # Place program components and generate an MLIR module + return Program(dev, rt).resolve_program(SequentialPlacer()) + + +if __name__ == "__main__": + + def str_to_device(device: str): + if device == "npu": + return NPU1() + elif device == "npu2": + return NPU2() + else: + raise ValueError(f"Device name {device} is unknown.") + + p = argparse.ArgumentParser() + p.add_argument( + "-d", + "--dev", + required=True, + dest="device", + help="AIE Device", + type=str_to_device, + ) + p.add_argument("-l", "--length", required=True, dest="length", help="Transfer size") + p.add_argument( + "-co", "--columns", required=True, dest="cols", help="Number of columns" + ) + p.add_argument( + "-ch", "--channels", required=True, dest="chans", help="Number of channels" + ) + p.add_argument( + "-ts", + "--tile-size", + required=False, + dest="tile_size", + default="1024", + help="Tile size (elements per tile)", + ) + p.add_argument( + "-t", "--trace-size", required=True, dest="trace_size", help="Trace size" + ) + p.add_argument( + "--output-file-path", + "-o", + type=str, + help="Output file path for the generated MLIR module", + ) + + opts = p.parse_args(sys.argv[1:]) + + length = int(opts.length) + columns = int(opts.cols) + dev = opts.device + + if isinstance(dev, NPU1) and columns > 4: + raise ValueError("[ERROR] NPU device cannot allocate more than 4 columns") + elif isinstance(dev, NPU2) and columns > 8: + raise ValueError("[ERROR] NPU2 device cannot allocate more than 8 columns") + + channels = int(opts.chans) + if channels < 1 or channels > 2: + raise 
ValueError("Number of channels must be 1 or 2") + tile_size = int(opts.tile_size) + if length % (tile_size * columns) != 0: + print( + "transfer size (" + + str(length) + + ") must be a multiple of " + + str(tile_size * columns) + + " (tile_size * columns)" + ) + raise ValueError + trace_size = int(opts.trace_size) if opts.trace_size is not None else 0 + + module = my_silu_mul(dev, length, columns, channels, tile_size, trace_size) + + output_file_path = Path(opts.output_file_path) + + with open(output_file_path, "w") as f: + f.write(str(module)) diff --git a/iron/operators/silu_mul/op.py b/iron/operators/silu_mul/op.py new file mode 100644 index 00000000..4bbc6402 --- /dev/null +++ b/iron/operators/silu_mul/op.py @@ -0,0 +1,166 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import torch +import numpy as np +from ml_dtypes import bfloat16 +from pathlib import Path + +from iron.common import ( + AIEOperatorBase, + AIEOperatorConstraintError, + XclbinArtifact, + InstsBinArtifact, + KernelObjectArtifact, + KernelArchiveArtifact, + SourceArtifact, + PythonGeneratedMLIRArtifact, +) + + +class AIESiLUMul(AIEOperatorBase): + """AIE-accelerated fused SiLU activation + element-wise multiplication""" + + def __init__( + self, size, num_aie_columns, num_channels, tile_size, trace_size=0, context=None + ): + max_multiple = num_aie_columns * tile_size + padded_size = ((size + max_multiple - 1) // max_multiple) * max_multiple + self.orig_size = size + self.size = padded_size + self.tile_size = tile_size + self.num_aie_columns = num_aie_columns + self.num_channels = num_channels + self.trace_size = trace_size + + total_shimdma_channels = self.num_aie_columns * self.num_channels + assert total_shimdma_channels <= 16, "Conservative ShimDMA limit" + + self.xclbin_artifact = None + self.insts_artifact = None + + AIEOperatorBase.__init__(self, context=context) + + def get_artifacts(self, 
prefix="silu_mul_"): + operator_dir = Path(__file__).parent + file_name_base = f"{prefix}{self.num_aie_columns}c_{self.num_channels}ch_{self.size}_{self.tile_size}t" + + mlir_artifact = PythonGeneratedMLIRArtifact.new( + f"{file_name_base}.mlir", + import_path=operator_dir / "design.py", + callback_fn="my_silu_mul", + callback_args=[ + self.context.device_manager.device_type, + self.size, + self.num_aie_columns, + self.num_channels, + self.tile_size, + self.trace_size, + ], + ) + + xclbin_artifact = XclbinArtifact.new( + f"{file_name_base}.xclbin", + depends=[ + mlir_artifact, + KernelObjectArtifact.new( + "silu_mul.o", + depends=[ + SourceArtifact.new( + self.context.base_dir + / "aie_kernels" + / "aie2p" + / "silu_mul.cc" + ) + ], + ), + ], + ) + + insts_artifact = InstsBinArtifact.new( + f"{file_name_base}.bin", depends=[mlir_artifact] + ) + + return xclbin_artifact, insts_artifact + + def set_up_artifacts(self): + xclbin_artifact, insts_artifact = self.get_artifacts() + + self.xclbin_artifact = xclbin_artifact + self.insts_artifact = insts_artifact + + artifacts = [xclbin_artifact, insts_artifact] + self.add_artifacts(artifacts) + + def set_up_runtime(self): + self.add_buffer("input1", self.size) + self.add_buffer("input2", self.size) + self.add_buffer("output", self.size) + self.add_kernel( + "silu_mul", + self.xclbin_artifact, + self.xclbin_artifact.kernel_name, + self.insts_artifact, + ) + self.add_to_runlist("silu_mul", "input1", "input2", "output") + + def forward(self, x, y): + """Forward pass for fused SiLU(x) * y""" + applicable = ( + len(x.shape) >= 1 + and len(y.shape) >= 1 + and x.shape[-1] <= self.size + and y.shape[-1] <= self.size + and x.numel() <= self.size + and y.numel() <= self.size + and x.numel() == y.numel() + and x.shape == y.shape + ) + if not applicable: + raise AIEOperatorConstraintError("AIESiLUMul: incompatible tensor shape(s)") + + # Always flatten to [batch, orig_size] + original_shape = x.shape + batch = x.shape[0] if x.dim() > 1 
else 1 + x_flat = x.reshape(batch, -1) + y_flat = y.reshape(batch, -1) + + pad_len = self.size - x_flat.shape[1] + if pad_len > 0: + x_flat = torch.nn.functional.pad(x_flat, (0, pad_len)) + y_flat = torch.nn.functional.pad(y_flat, (0, pad_len)) + + out = self._execute_aie_operation(x_flat, y_flat) + + # Remove padding if added + numel = np.prod(original_shape) + if pad_len > 0: + out = out.reshape(-1)[..., :numel] + # Restore original shape + out = out.reshape(*original_shape) + + return out + + def _execute_aie_operation(self, x, y): + """Execute fused SiLU + multiply operation on AIE hardware""" + # x, y are [batch, size] + batch = x.shape[0] if x.dim() > 1 else 1 + + # Flatten inputs for AIE processing + x_flat = x.view(-1) + y_flat = y.view(-1) + + # Verify size matches expected + if len(x_flat) != self.size or len(y_flat) != self.size: + raise AIEOperatorConstraintError( + f"Input size x={len(x_flat)}, y={len(y_flat)} doesn't match configured size {self.size}" + ) + + self.write_buffer("input1", x_flat) + self.write_buffer("input2", y_flat) + test_pattern = np.zeros(len(x_flat), dtype=bfloat16) + self.write_buffer("output", test_pattern) + self.run_runlist() + result = self.read_buffer_as_torch("output", shape=x_flat.shape, dtype=bfloat16) + + return result diff --git a/iron/operators/silu_mul/reference.py b/iron/operators/silu_mul/reference.py new file mode 100644 index 00000000..dbd5b11b --- /dev/null +++ b/iron/operators/silu_mul/reference.py @@ -0,0 +1,15 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. 
+# SPDX-License-Identifier: Apache-2.0 + +import torch +from iron.common.utils import torch_dtype_map + + +def generate_golden_reference(input_length: int, dtype="bf16", seed=42): + torch.manual_seed(seed) + val_range = 4 + dtype_torch = torch_dtype_map[dtype] + input_a = torch.rand(input_length, dtype=dtype_torch) * val_range + input_b = torch.rand(input_length, dtype=dtype_torch) * val_range + output = torch.nn.functional.silu(input_a) * input_b + return {"A": input_a, "B": input_b, "C": output} diff --git a/iron/operators/silu_mul/test.py b/iron/operators/silu_mul/test.py new file mode 100644 index 00000000..c9d5ded8 --- /dev/null +++ b/iron/operators/silu_mul/test.py @@ -0,0 +1,78 @@ +#!/usr/bin/env python3 +# SPDX-FileCopyrightText: Copyright (C) 2025 Advanced Micro Devices, Inc. All rights reserved. +# SPDX-License-Identifier: Apache-2.0 + +import sys +import pytest +from pathlib import Path + + +from iron.operators.silu_mul.op import AIESiLUMul +from iron.operators.silu_mul.reference import generate_golden_reference +from iron.common.test_utils import run_test + + +def generate_test_params(extensive=False): + max_aie_columns = 8 + num_channels = 2 + input_lengths = [2048] if not extensive else [1024, 4096, 8192] + + params = [] + names = [] + for input_length in input_lengths: + for num_aie_columns in range(1, max_aie_columns + 1): + tile_size = input_length // num_aie_columns + if tile_size > 4096: + tile_size = 4096 + if tile_size * num_aie_columns != input_length: + continue + names.append( + f"silu_mul_{num_aie_columns}_cols_{num_channels}_channels_{input_length}_tile_{tile_size}" + ) + params.append((input_length, num_aie_columns, num_channels, tile_size)) + return params, names + + +regular_params, regular_names = generate_test_params(extensive=False) +extensive_params, extensive_names = generate_test_params(extensive=True) + +# Combine params with marks - extensive params get pytest.mark.extensive +all_params = [ + pytest.param(*params, id=name) + for 
params, name in zip(regular_params, regular_names) +] + [ + pytest.param(*params, marks=pytest.mark.extensive, id=name) + for params, name in zip(extensive_params, extensive_names) +] + + +@pytest.mark.metrics( + Latency=r"Latency \(us\): (?P[\d\.]+)", + Bandwidth=r"Effective Bandwidth: (?P[\d\.e\+-]+) GB/s", +) +@pytest.mark.parametrize( + "input_length,num_aie_columns,num_channels,tile_size", + all_params, +) +def test_silu_mul(input_length, num_aie_columns, num_channels, tile_size, aie_context): + golden_ref = generate_golden_reference(input_length=input_length) + + operator = AIESiLUMul( + size=input_length, + num_aie_columns=num_aie_columns, + num_channels=num_channels, + tile_size=tile_size, + context=aie_context, + ) + + input_buffers = {"input1": golden_ref["A"], "input2": golden_ref["B"]} + output_buffers = {"output": golden_ref["C"]} + + errors, latency_us, bandwidth_gbps = run_test( + operator, input_buffers, output_buffers, rel_tol=0.04, abs_tol=1e-6 + ) + + print(f"\nLatency (us): {latency_us:.1f}") + print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n") + + assert not errors, f"Test failed with errors: {errors}" diff --git a/iron/operators/swiglu_decode/op.py b/iron/operators/swiglu_decode/op.py index 869493c9..00f15640 100644 --- a/iron/operators/swiglu_decode/op.py +++ b/iron/operators/swiglu_decode/op.py @@ -15,9 +15,8 @@ SourceArtifact, PythonGeneratedMLIRArtifact, ) +from iron.operators.dual_gemv_silu_mul.op import AIEDualGEMVSiLUMul, interleave_weights from iron.operators.gemv.op import AIEGEMV -from iron.operators.silu.op import AIESiLU -from iron.operators.elementwise_mul.op import AIEElementwiseMul from iron.common.utils import torch_to_numpy @@ -34,12 +33,8 @@ def __init__(self, embedding_dim, hidden_dim, prio_accuracy=False, context=None) # Artifacts created by set_up_artifacts() self.combined_xclbin = None - self.gemv_1_xclbin = None - self.gemv_1_insts = None - self.silu_xclbin = None - self.silu_insts = None - 
self.eltwise_mul_xclbin = None - self.eltwise_mul_insts = None + self.fused_xclbin = None + self.fused_insts = None self.gemv_2_xclbin = None self.gemv_2_insts = None @@ -49,63 +44,22 @@ def set_up_artifacts(self): artifacts = [] device_str = self.context.device_manager.device_str() - gemv_1 = AIEGEMV( + fused = AIEDualGEMVSiLUMul( M=self.hidden_dim, K=self.embedding_dim, - num_aie_columns=8, + num_aie_columns=4, tile_size_input=4, - tile_size_output=self.hidden_dim // 8, - ) - self.gemv_1 = gemv_1 - gemv_1_xclbin, gemv_1_insts = gemv_1.get_artifacts( - prefix="swiglu_decode_gemv_1_" + tile_size_output=self.hidden_dim // 4, ) - gemv_1_xclbin.extra_flags += [ - "--xclbin-instance-name=swiglu_gemv_1", + self.fused = fused + self.hidden_dim_padded = self.hidden_dim + fused_xclbin, fused_insts = fused.get_artifacts(prefix="swiglu_decode_fused_") + fused_xclbin.extra_flags += [ + "--xclbin-instance-name=swiglu_fused", "--xclbin-kernel-id=0x901", ] - gemv_1_xclbin.kernel_name = "swiglu_gemv_1" - artifacts.append( - gemv_1_insts - ) # xclbin artifact will be pulled in as a dependency of last xclbin - - silu = AIESiLU( - size=self.hidden_dim, - num_aie_columns=8, - num_channels=2, - tile_size=self.hidden_dim // 16, - ) - self.silu = silu - self.hidden_dim_padded = silu.size - silu_xclbin, silu_insts = silu.get_artifacts(prefix="swiglu_decode_silu_") - silu_xclbin.xclbin_input = gemv_1_xclbin - silu_xclbin.extra_flags += [ - "--xclbin-instance-name=swiglu_silu", - "--xclbin-kernel-id=0x902", - ] - silu_xclbin.kernel_name = "swiglu_silu" - silu_xclbin.depends += [gemv_1_xclbin] - artifacts.append(silu_insts) - - eltwise_mul = AIEElementwiseMul( - size=self.hidden_dim, - num_aie_columns=8, - num_channels=2, - tile_size=self.hidden_dim // 8, - ) - self.eltwise_mul = eltwise_mul - assert self.hidden_dim <= eltwise_mul.size <= self.hidden_dim_padded - eltwise_mul_xclbin, eltwise_mul_insts = eltwise_mul.get_artifacts( - prefix="swiglu_decode_eltwise_mul_" - ) - 
eltwise_mul_xclbin.xclbin_input = silu_xclbin - eltwise_mul_xclbin.extra_flags += [ - "--xclbin-instance-name=swiglu_eltwise_mul", - "--xclbin-kernel-id=0x903", - ] - eltwise_mul_xclbin.kernel_name = "swiglu_eltwise_mul" - eltwise_mul_xclbin.depends += [silu_xclbin] - artifacts.append(eltwise_mul_insts) + fused_xclbin.kernel_name = "swiglu_fused" + artifacts.append(fused_insts) gemv_2 = AIEGEMV( M=self.embedding_dim, @@ -118,23 +72,19 @@ def set_up_artifacts(self): gemv_2_xclbin, gemv_2_insts = gemv_2.get_artifacts( prefix="swiglu_decode_gemv_2_" ) - gemv_2_xclbin.xclbin_input = eltwise_mul_xclbin + gemv_2_xclbin.xclbin_input = fused_xclbin gemv_2_xclbin.extra_flags += [ "--xclbin-instance-name=swiglu_gemv_2", - "--xclbin-kernel-id=0x904", + "--xclbin-kernel-id=0x902", ] gemv_2_xclbin.kernel_name = "swiglu_gemv_2" - gemv_2_xclbin.depends += [eltwise_mul_xclbin] + gemv_2_xclbin.depends += [fused_xclbin] artifacts.append(gemv_2_xclbin) artifacts.append(gemv_2_insts) self.combined_xclbin = gemv_2_xclbin - self.gemv_1_xclbin = gemv_1_xclbin - self.gemv_1_insts = gemv_1_insts - self.silu_xclbin = silu_xclbin - self.silu_insts = silu_insts - self.eltwise_mul_xclbin = eltwise_mul_xclbin - self.eltwise_mul_insts = eltwise_mul_insts + self.fused_xclbin = fused_xclbin + self.fused_insts = fused_insts self.gemv_2_xclbin = gemv_2_xclbin self.gemv_2_insts = gemv_2_insts @@ -142,43 +92,28 @@ def set_up_artifacts(self): def set_up_runtime(self): self.add_buffer("input", self.embedding_dim) - self.add_buffer( - "weights_1", - self.embedding_dim * self.hidden_dim_padded, - static_data=torch_to_numpy(self.weights_1), + # Pre-interleave W1 and W2 for the fused dual-GEMV design + rows_per_col = self.hidden_dim // self.fused.num_aie_columns + w_interleaved = interleave_weights( + self.weights_1, self.weights_2, rows_per_col, self.fused.num_aie_columns ) self.add_buffer( - "weights_2", - self.embedding_dim * self.hidden_dim_padded, - static_data=torch_to_numpy(self.weights_2), + 
"weights_gate_up", + 2 * self.embedding_dim * self.hidden_dim_padded, + static_data=torch_to_numpy(w_interleaved), ) self.add_buffer( "weights_3", self.hidden_dim_padded * self.embedding_dim, static_data=torch_to_numpy(self.weights_3), ) - self.add_buffer("left", self.hidden_dim_padded) - self.add_buffer("left_swished", self.hidden_dim_padded) - self.add_buffer("right", self.hidden_dim_padded) self.add_buffer("intermediate", self.hidden_dim_padded) self.add_buffer("output", self.embedding_dim) self.add_kernel( - "swiglu_gemv_1", - self.combined_xclbin, - self.gemv_1_xclbin.kernel_name, - self.gemv_1_insts, - ) - self.add_kernel( - "swiglu_silu", + "swiglu_fused", self.combined_xclbin, - self.silu_xclbin.kernel_name, - self.silu_insts, - ) - self.add_kernel( - "swiglu_eltwise_mul", - self.combined_xclbin, - self.eltwise_mul_xclbin.kernel_name, - self.eltwise_mul_insts, + self.fused_xclbin.kernel_name, + self.fused_insts, ) self.add_kernel( "swiglu_gemv_2", @@ -186,12 +121,7 @@ def set_up_runtime(self): self.gemv_2_xclbin.kernel_name, self.gemv_2_insts, ) - self.add_to_runlist("swiglu_gemv_1", "weights_1", "input", "left") - self.add_to_runlist("swiglu_gemv_1", "weights_2", "input", "right") - self.add_to_runlist("swiglu_silu", "left", "left_swished") - self.add_to_runlist( - "swiglu_eltwise_mul", "left_swished", "right", "intermediate" - ) + self.add_to_runlist("swiglu_fused", "weights_gate_up", "input", "intermediate") self.add_to_runlist("swiglu_gemv_2", "weights_3", "intermediate", "output") def forward(self, x): diff --git a/iron/operators/swiglu_decode/test.py b/iron/operators/swiglu_decode/test.py index 11b35fa2..e54e336a 100755 --- a/iron/operators/swiglu_decode/test.py +++ b/iron/operators/swiglu_decode/test.py @@ -55,9 +55,6 @@ def test_swiglu_decode(embedding_dim, hidden_dim, aie_context): input_buffers = {"input": golden_ref["input"]} output_buffers = {} intermediate_buffers = { - "left": golden_ref["left"], - "left_swished": golden_ref["left_swished"], - 
"right": golden_ref["right"], "intermediate": golden_ref["intermediate"], } diff --git a/iron/operators/swiglu_prefill/op.py b/iron/operators/swiglu_prefill/op.py index 2b2aa341..20359b45 100644 --- a/iron/operators/swiglu_prefill/op.py +++ b/iron/operators/swiglu_prefill/op.py @@ -16,8 +16,7 @@ PythonGeneratedMLIRArtifact, ) from iron.operators.gemm.op import AIEGEMM -from iron.operators.silu.op import AIESiLU -from iron.operators.elementwise_mul.op import AIEElementwiseMul +from iron.operators.silu_mul.op import AIESiLUMul from iron.common.utils import torch_to_numpy @@ -39,10 +38,8 @@ def __init__( self.combined_xclbin = None self.gemm_1_xclbin = None self.gemm_1_insts = None - self.silu_xclbin = None - self.silu_insts = None - self.eltwise_mul_xclbin = None - self.eltwise_mul_insts = None + self.silu_mul_xclbin = None + self.silu_mul_insts = None self.gemm_2_xclbin = None self.gemm_2_insts = None @@ -51,7 +48,7 @@ def __init__( def set_up_artifacts(self): # Artifact setup # --- - # Note: All operators (GEMM, SiLU, ElementwiseMul) apply their own padding + # Note: All operators (GEMM, SiLUMul) apply their own padding # to meet hardware alignment requirements. We store the padded dimensions # from GEMM and verify that all operators use consistent padded sizes. 
artifacts = [] @@ -82,45 +79,26 @@ def set_up_artifacts(self): gemm_1_insts ) # xclbin artifact will be pulled in as a dependency of last xclbin - silu = AIESiLU( + silu_mul = AIESiLUMul( size=self.seq_len_padded * self.hidden_dim_padded, num_aie_columns=8, num_channels=2, tile_size=self.hidden_dim_padded // 8, ) - self.silu = silu - assert silu.size == self.seq_len_padded * self.hidden_dim_padded + self.silu_mul = silu_mul + assert silu_mul.size == self.seq_len_padded * self.hidden_dim_padded - silu_xclbin, silu_insts = silu.get_artifacts(prefix="swiglu_silu_") - silu_xclbin.xclbin_input = gemm_1_xclbin - silu_xclbin.extra_flags += [ - "--xclbin-instance-name=swiglu_silu", - "--xclbin-kernel-id=0x902", - ] - silu_xclbin.kernel_name = "swiglu_silu" - silu_xclbin.depends += [gemm_1_xclbin] - artifacts.append(silu_insts) - - eltwise_mul = AIEElementwiseMul( - size=self.seq_len_padded * self.hidden_dim_padded, - num_aie_columns=8, - num_channels=2, - tile_size=self.hidden_dim_padded // 8, - ) - self.eltwise_mul = eltwise_mul - assert eltwise_mul.size == self.seq_len_padded * self.hidden_dim_padded - - eltwise_mul_xclbin, eltwise_mul_insts = eltwise_mul.get_artifacts( - prefix="swiglu_eltwise_mul_" + silu_mul_xclbin, silu_mul_insts = silu_mul.get_artifacts( + prefix="swiglu_silu_mul_" ) - eltwise_mul_xclbin.xclbin_input = silu_xclbin - eltwise_mul_xclbin.extra_flags += [ - "--xclbin-instance-name=swiglu_eltwise_mul", - "--xclbin-kernel-id=0x903", + silu_mul_xclbin.xclbin_input = gemm_1_xclbin + silu_mul_xclbin.extra_flags += [ + "--xclbin-instance-name=swiglu_silu_mul", + "--xclbin-kernel-id=0x902", ] - eltwise_mul_xclbin.kernel_name = "swiglu_eltwise_mul" - eltwise_mul_xclbin.depends += [silu_xclbin] - artifacts.append(eltwise_mul_insts) + silu_mul_xclbin.kernel_name = "swiglu_silu_mul" + silu_mul_xclbin.depends += [gemm_1_xclbin] + artifacts.append(silu_mul_insts) gemm_2 = AIEGEMM( M=self.seq_len, K=self.hidden_dim, N=self.embedding_dim, **accuracy_flags @@ -131,23 
+109,21 @@ def set_up_artifacts(self): assert gemm_2.N == self.embedding_dim_padded gemm_2_xclbin, gemm_2_insts = gemm_2.get_artifacts(prefix="swiglu_gemm_2_") - gemm_2_xclbin.xclbin_input = eltwise_mul_xclbin + gemm_2_xclbin.xclbin_input = silu_mul_xclbin gemm_2_xclbin.extra_flags += [ "--xclbin-instance-name=swiglu_gemm_2", - "--xclbin-kernel-id=0x904", + "--xclbin-kernel-id=0x903", ] gemm_2_xclbin.kernel_name = "swiglu_gemm_2" - gemm_2_xclbin.depends += [eltwise_mul_xclbin] + gemm_2_xclbin.depends += [silu_mul_xclbin] artifacts.append(gemm_2_xclbin) artifacts.append(gemm_2_insts) self.combined_xclbin = gemm_2_xclbin self.gemm_1_xclbin = gemm_1_xclbin self.gemm_1_insts = gemm_1_insts - self.silu_xclbin = silu_xclbin - self.silu_insts = silu_insts - self.eltwise_mul_xclbin = eltwise_mul_xclbin - self.eltwise_mul_insts = eltwise_mul_insts + self.silu_mul_xclbin = silu_mul_xclbin + self.silu_mul_insts = silu_mul_insts self.gemm_2_xclbin = gemm_2_xclbin self.gemm_2_insts = gemm_2_insts @@ -173,7 +149,6 @@ def set_up_runtime(self): static_data=torch_to_numpy(self.weights_3.T), ) self.add_buffer("left", self.seq_len_padded * self.hidden_dim_padded) - self.add_buffer("left_swished", self.seq_len_padded * self.hidden_dim_padded) self.add_buffer("right", self.seq_len_padded * self.hidden_dim_padded) self.add_buffer("intermediate", self.seq_len_padded * self.hidden_dim_padded) self.add_buffer("output", self.seq_len_padded * self.embedding_dim_padded) @@ -184,16 +159,10 @@ def set_up_runtime(self): self.gemm_1_insts, ) self.add_kernel( - "swiglu_silu", + "swiglu_silu_mul", self.combined_xclbin, - self.silu_xclbin.kernel_name, - self.silu_insts, - ) - self.add_kernel( - "swiglu_eltwise_mul", - self.combined_xclbin, - self.eltwise_mul_xclbin.kernel_name, - self.eltwise_mul_insts, + self.silu_mul_xclbin.kernel_name, + self.silu_mul_insts, ) self.add_kernel( "swiglu_gemm_2", @@ -203,10 +172,7 @@ def set_up_runtime(self): ) self.add_to_runlist("swiglu_gemm_1", "input", 
"weights_1", "left") self.add_to_runlist("swiglu_gemm_1", "input", "weights_2", "right") - self.add_to_runlist("swiglu_silu", "left", "left_swished") - self.add_to_runlist( - "swiglu_eltwise_mul", "left_swished", "right", "intermediate" - ) + self.add_to_runlist("swiglu_silu_mul", "left", "right", "intermediate") self.add_to_runlist("swiglu_gemm_2", "intermediate", "weights_3", "output") def forward(self, x): diff --git a/iron/operators/swiglu_prefill/test.py b/iron/operators/swiglu_prefill/test.py index 75510d63..53f06dae 100755 --- a/iron/operators/swiglu_prefill/test.py +++ b/iron/operators/swiglu_prefill/test.py @@ -56,9 +56,8 @@ def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_c output_buffers = {} intermediate_buffers = { "left": golden_ref["left"], - "left_swished": golden_ref["left_swished"], "right": golden_ref["right"], - # 'intermediate': golden_ref['intermediate'] + "intermediate": golden_ref["intermediate"], } errors, latency_us, bandwidth_gbps = run_test( @@ -70,20 +69,13 @@ def test_swiglu_prefill(seq_len, embedding_dim, hidden_dim, prio_accuracy, aie_c abs_tol=0.7, ) - ref_2 = operator.read_buffer_as_torch( - "left_swished", (seq_len, hidden_dim) - ) * operator.read_buffer_as_torch("right", (seq_len, hidden_dim)) - errors_2 = verify_buffer(operator, "intermediate", ref_2, rel_tol=0.04, abs_tol=0.4) - if errors_2: - errors["intermediate"] = errors_2 - ref_3 = ( operator.read_buffer_as_torch("intermediate", (seq_len, hidden_dim)) @ golden_ref["w_down"] ) errors_3 = verify_buffer(operator, "output", ref_3, rel_tol=0.04, abs_tol=0.4) if errors_3: - errors["output"] = errors_2 + errors["output"] = errors_3 print(f"\nLatency (us): {latency_us:.1f}") print(f"Effective Bandwidth: {bandwidth_gbps:.6e} GB/s\n")