A CUDA implementation of Flash Attention with TopK selection, built using CUTLASS and CUTE (CUDA Templates) libraries.
This project implements an efficient Flash Attention mechanism with TopK selection on NVIDIA GPUs. It uses CUTLASS for efficient matrix operations and CUTE for tensor abstractions.
- Flash Attention implementation optimized for NVIDIA GPUs
- TopK selection integrated into attention computation
- Support for FP16 data type
- Configurable batch size, number of heads, sequence length, and head dimensions
- CPU reference implementation for result verification
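
To make the verification feature above concrete, here is a minimal host-side sketch of a single-head attention reference with TopK selection. It is not taken from `src/test_flash_attn.cu`. The function name `topk_attention_ref`, the use of `float` instead of FP16, and the choice to apply TopK to the scaled scores before the softmax are all assumptions made for illustration, and the project may define the selection differently.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical reference routine (not the project's code): single-head
// attention in which each query attends only to its top_k highest-scoring keys.
// q is [seq_q x d]; k_mat and v are [seq_k x d]; all are row-major float.
std::vector<float> topk_attention_ref(const std::vector<float>& q,
                                      const std::vector<float>& k_mat,
                                      const std::vector<float>& v,
                                      std::size_t seq_q, std::size_t seq_k,
                                      std::size_t d, std::size_t top_k) {
    assert(q.size() == seq_q * d);
    assert(k_mat.size() == seq_k * d && v.size() == seq_k * d);

    std::vector<float> out(seq_q * d, 0.0f);
    top_k = std::min(top_k, seq_k);
    if (top_k == 0 || d == 0) return out;

    const float scale = 1.0f / std::sqrt(static_cast<float>(d));
    std::vector<float> scores(seq_k);

    for (std::size_t i = 0; i < seq_q; ++i) {
        // Scaled dot-product scores of query row i against every key.
        for (std::size_t j = 0; j < seq_k; ++j) {
            float dot = 0.0f;
            for (std::size_t c = 0; c < d; ++c) dot += q[i * d + c] * k_mat[j * d + c];
            scores[j] = dot * scale;
        }

        // The top_k-th largest score is the cut-off. Scores tied with the
        // cut-off are all kept, so more than top_k keys can survive on ties.
        std::vector<float> sorted(scores);
        std::nth_element(sorted.begin(), sorted.begin() + (top_k - 1),
                         sorted.end(), std::greater<float>());
        const float threshold = sorted[top_k - 1];
        const float row_max = *std::max_element(scores.begin(), scores.end());

        // Numerically stable softmax over the surviving keys, fused with the
        // weighted sum of the matching value rows.
        float denom = 0.0f;
        for (std::size_t j = 0; j < seq_k; ++j) {
            if (scores[j] < threshold) continue;
            const float w = std::exp(scores[j] - row_max);
            denom += w;
            for (std::size_t c = 0; c < d; ++c) out[i * d + c] += w * v[j * d + c];
        }
        for (std::size_t c = 0; c < d; ++c) out[i * d + c] /= denom;
    }
    return out;
}
```

A reference like this is typically compared element-wise against the GPU output with a tolerance loose enough for FP16 rounding.
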
- CUDA Toolkit 11.0 or higher
- CMake 3.18 or higher
- NVIDIA GPU with Compute Capability 8.0 or higher (Ampere architecture)
- C++17 compiler
- Clone the repository:
      git clone <repository-url>
      cd <repository-name>
- Build the project:
      mkdir build
      cd build
      cmake ..
      make

Run the test program:

    ./test_flash_attn

This will execute the Flash Attention kernel and print the results.
- src/test_flash_attn.cu: Main test file and CPU reference implementation
- kernel_fwd.h: Forward pass kernel implementation
- kernel_traits.h: Kernel configuration traits
- topk.h: TopK implementation
- utils.h: Utility functions and helpers
- include/: External dependencies and headers
- Uses CUTLASS's MMA (Matrix Multiply-Accumulate) operations for efficient attention computation
- Implements custom TopK selection using bitonic sort
- Supports configurable block sizes and thread organization
- Includes debug printing capabilities for development
- Optimized for Ampere architecture (SM80)
- Uses shared memory for efficient data access
- Implements efficient memory access patterns
- Utilizes Tensor Core matrix operations
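
The bitonic-sort TopK mentioned above can be illustrated with a small host-side sketch. This is not the code in `topk.h`. The names `bitonic_sort_desc` and `topk_bitonic` are hypothetical, and padding the input to a power-of-two length with negative infinity is an assumption chosen to keep the sorting network simple.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <limits>
#include <utility>
#include <vector>

// Hypothetical helper: in-place bitonic sorting network, descending order.
// The length must be a power of two.
void bitonic_sort_desc(std::vector<float>& a) {
    const std::size_t n = a.size();
    assert(n != 0 && (n & (n - 1)) == 0);

    // k is the size of the bitonic sequences being merged in this phase;
    // j is the compare-exchange distance within the phase.
    for (std::size_t k = 2; k <= n; k <<= 1) {
        for (std::size_t j = k >> 1; j > 0; j >>= 1) {
            // Every iteration of this loop is independent of the others,
            // which is why the stage maps naturally onto parallel GPU threads.
            for (std::size_t i = 0; i < n; ++i) {
                const std::size_t partner = i ^ j;
                if (partner <= i) continue;  // handle each pair once
                const bool descending = (i & k) == 0;
                const bool out_of_order =
                    descending ? (a[i] < a[partner]) : (a[i] > a[partner]);
                if (out_of_order) std::swap(a[i], a[partner]);
            }
        }
    }
}

// Hypothetical helper: returns the k largest values in descending order.
std::vector<float> topk_bitonic(std::vector<float> values, std::size_t k) {
    k = std::min(k, values.size());
    if (k == 0) return {};

    // Pad to the next power of two with -inf so padding sorts to the end.
    std::size_t n = 1;
    while (n < values.size()) n <<= 1;
    values.resize(n, -std::numeric_limits<float>::infinity());

    bitonic_sort_desc(values);
    values.resize(k);
    return values;
}
```

For example, `topk_bitonic({0.5f, 2.0f, -1.0f, 3.5f, 1.0f}, 3)` pads the input to eight elements and returns `{3.5f, 2.0f, 1.0f}`. A full sort does more work than a selection strictly needs, but its fixed, data-independent comparison pattern avoids divergent branches on the GPU.
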
- NVIDIA CUTLASS library
- CUDA Templates (CuTe) library, which is distributed as part of CUTLASS 3.x