From b5b20b814f60f6435b73b75653c1a8a7244acefa Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 02:50:56 +0000 Subject: [PATCH 01/37] update --- csrc/CMakeLists.txt | 93 ++++++++++++++++++++++++++------------------- csrc/setup.py | 15 ++++++++ csrc/tp.py | 75 ++++++++++++++++++++++++++++++++++++ csrc/yest | 3 ++ yes.py | 12 ++++++ 5 files changed, 159 insertions(+), 39 deletions(-) create mode 100644 csrc/setup.py create mode 100644 csrc/tp.py create mode 100644 csrc/yest create mode 100644 yes.py diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 5b305fa98c0..04597cf3e28 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -13,38 +13,38 @@ set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) set(FA2_SOURCES_CU flash_attn/src/cuda_utils.cu - flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu ) add_library(flashattn SHARED @@ -57,13 +57,13 @@ target_include_directories(flashattn PRIVATE set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu - flash_attn_with_bias_and_mask/src/cuda_utils.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu + #flash_attn_with_bias_and_mask/src/cuda_utils.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu flash_attn_with_bias_and_mask/src/utils.cu) add_library(flashattn_with_bias_mask STATIC @@ -131,7 +131,22 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$) +# INSTALL(TARGETS flashattn LIBRARY DESTINATION "lib") INSTALL(FILES capi/flash_attn.h DESTINATION "include") + +add_custom_target(run_my_executable + COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + DEPENDS flashattn + COMMENT "Running my_executable" +) + +# 创建一个伪目标作为默认构建目标 +add_custom_target(default_target DEPENDS run_my_executable) + +# 设置 'default_target' 为默认构建目标 +set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target) + diff --git a/csrc/setup.py b/csrc/setup.py new file mode 100644 index 00000000000..92a9a3f16ff --- /dev/null +++ b/csrc/setup.py @@ -0,0 +1,15 @@ +from setuptools import setup, find_packages +from setuptools import setup, find_namespace_packages + +setup( + packages=find_packages(where="src"), + package_dir={"": "src"}, + package_data={"": ["*.so"]}, + exclude_package_data={"flash_attn_with_bias_and_mask": ["*"]}, + include_package_data=True, + #packages=find_namespace_packages(where="src"), + #package_dir={"": "src"}, + #package_data={ + # "": ["*.so"], + #} +) diff --git a/csrc/tp.py b/csrc/tp.py new file mode 100644 index 00000000000..5b5d9d95d8c --- /dev/null +++ b/csrc/tp.py @@ -0,0 +1,75 @@ +import paddle +from setuptools import setup, find_packages +import sys +import os + +python_version = sys.version +print("Installing your_package...") + +# Get the CUDA version from PaddlePaddle +cuda_version = paddle.version.cuda() +fa_version = f"1.0.0.post{cuda_version}" +package_name = 'flash_attention_paddle_gpu' + +def get_data_files(): + data_files = [] + + # Assuming 'libflashattn.so' is located in the same directory as setup.py + source_lib_path = 'libflashattn.so' + + # Specify the destination directory within the package + destination_lib_path = os.path.join(package_name, 'libflashattn.so') + + data_files.append((os.path.join(package_name, 'libflashattn.so'), [source_lib_path])) + print(destination_lib_path, "asdf ****************") + print(data_files) + return data_files + +setup( + name=package_name, + version=fa_version, + data_files=get_data_files(), + description='Flash attention in paddlepaddle', + packages=find_packages(), + package_data={package_name: ['src/libflashattn.so']}, +) +# +#import paddle +#import os +#from setuptools import setup +#import sys +# +#python_version = sys.version +#print("Installing your_package...") +# +## Get the CUDA version from PaddlePaddle +#cuda_version = paddle.version.cuda() +#fa_version = f"1.0.0.post{cuda_version}" +#package_name = 'flash_attention_paddle_gpu' # Adjusted package name +# +#def get_data_files(): +# data_files = [] +# +# # Assuming 'libflashattn.so' is located in the same directory as setup.py +# source_lib_path = os.path.abspath('libflashattn.so') +# +# # Specify the destination directory within the package +# destination_lib_path = os.path.join(package_name, 'libflashattn.so') +# +# data_files.append((os.path.join(package_name, 'libflashattn.so'), [source_lib_path])) +# print(destination_lib_path, "asdf ****************") +# print(data_files) +# return data_files +# +## Create an empty __init__.py file in the package directory +#init_file_path = os.path.join(package_name, '__init__.py') +#with open(init_file_path, 'w') as f: +# pass +# +#setup( +# name=package_name, +# version=fa_version, +# description='Flash attention in paddlepaddle', +# packages=[package_name], +# package_data={package_name: ['libflashattn.so']}, +#) diff --git a/csrc/yest b/csrc/yest new file mode 100644 index 00000000000..b3d4d3cd0b5 --- /dev/null +++ b/csrc/yest @@ -0,0 +1,3 @@ +include build/libflashattn.so +include src/libflashattn.so +include ./libflashattn.so diff --git a/yes.py b/yes.py new file mode 100644 index 00000000000..29917c43ecc --- /dev/null +++ b/yes.py @@ -0,0 +1,12 @@ +from setuptools import setup + +package_name = '' #flash-attention-paddle-gpu' +setup( + name=package_name, + version='1.0.0', + description='Flash attention in PaddlePaddle', + packages=[package_name], + include_package_data=True, + package_data={package_name: ['csrc/build/libflashattn.so']}, +) + From 199b9d68ad1d89e724987bbd9de16a92a7c6cf38 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 03:22:28 +0000 Subject: [PATCH 02/37] has data --- csrc/CMakeLists.txt | 2 +- csrc/tp.py | 6 ++++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 04597cf3e28..639d8b1eb4f 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -138,7 +138,7 @@ INSTALL(TARGETS flashattn INSTALL(FILES capi/flash_attn.h DESTINATION "include") add_custom_target(run_my_executable - COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist + COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/tp.py sdist bdist_wheel WORKING_DIRECTORY ${CMAKE_BINARY_DIR} DEPENDS flashattn COMMENT "Running my_executable" diff --git a/csrc/tp.py b/csrc/tp.py index 5b5d9d95d8c..56aebee074d 100644 --- a/csrc/tp.py +++ b/csrc/tp.py @@ -2,7 +2,9 @@ from setuptools import setup, find_packages import sys import os - +import paddle +paddle_path = paddle.sysconfig.get_lib +print(paddle_path) python_version = sys.version print("Installing your_package...") @@ -31,7 +33,7 @@ def get_data_files(): data_files=get_data_files(), description='Flash attention in paddlepaddle', packages=find_packages(), - package_data={package_name: ['src/libflashattn.so']}, + package_data={package_name: ['build/libflashattn.so']}, ) # #import paddle From a582b3a9e571bea2f65284c0626756ad7ff589a3 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 04:47:09 +0000 Subject: [PATCH 03/37] update --- csrc/README.md | 234 ++++++++++++++++++++++++++++++++++ csrc/mp.py | 336 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 570 insertions(+) create mode 100644 csrc/README.md create mode 100644 csrc/mp.py diff --git a/csrc/README.md b/csrc/README.md new file mode 100644 index 00000000000..79d33453003 --- /dev/null +++ b/csrc/README.md @@ -0,0 +1,234 @@ +# FlashAttention +This repository provides the official implementation of FlashAttention and +FlashAttention-2 from the +following papers. + +**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness** +Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré +Paper: https://arxiv.org/abs/2205.14135 +IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention. +![FlashAttention](assets/flashattn_banner.jpg) + +**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning** +Tri Dao + +Paper: https://tridao.me/publications/flash2/flash2.pdf + +![FlashAttention-2](assets/flashattention_logo.png) + + +## Usage + +We've been very happy to see FlashAttention being widely adopted in such a short +time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md) +contains a partial list of places where FlashAttention is being used. + +FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). +Please cite and credit FlashAttention if you use it. + +## Installation and features + +Requirements: +- CUDA 11.4 and above. +- PyTorch 1.12 and above. + +We recommend the +[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) +container from Nvidia, which has all the required tools to install FlashAttention. + +To install: +1. Make sure that PyTorch is installed. +2. Make sure that `packaging` is installed (`pip install packaging`) +3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja +--version` then `echo $?` should return exit code 0). If not (sometimes `ninja +--version` then `echo $?` returns a nonzero exit code), uninstall then reinstall +`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`, +compiling can take a very long time (2h) since it does not use multiple CPU +cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine. +4. Then: +```sh +pip install flash-attn --no-build-isolation +``` +Alternatively you can compile from source: +```sh +python setup.py install +``` + +If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might +run too many parallel compilation jobs that could exhaust the amount of RAM. To +limit the number of parallel compilation jobs, you can set the environment +variable `MAX_JOBS`: +```sh +MAX_JOBS=4 pip install flash-attn --no-build-isolation +``` + +Interface: `src/flash_attention_interface.py` + +FlashAttention-2 currently supports: +1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing + GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing + GPUs for now. +2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs). +3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800. + + +## How to use FlashAttention + +The main functions implement scaled dot product attention (softmax(Q @ K^T * +softmax_scale) @ V): +```python +from flash_attn import flash_attn_qkvpacked_func, flash_attn_func +``` + +```python +flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False): +"""dropout_p should be set to 0.0 during evaluation +If Q, K, V are already stacked into 1 tensor, this function will be faster than +calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation +of the gradients of Q, K, V. +Arguments: + qkv: (batch_size, seqlen, 3, nheads, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). +Return: + out: (batch_size, seqlen, nheads, headdim). +""" +``` + +```python +flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False): +"""dropout_p should be set to 0.0 during evaluation +Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads +than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. +For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head +0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. + +Arguments: + q: (batch_size, seqlen, nheads, headdim) + k: (batch_size, seqlen, nheads_k, headdim) + v: (batch_size, seqlen, nheads_k, headdim) + dropout_p: float. Dropout probability. + softmax_scale: float. The scaling of QK^T before applying softmax. + Default to 1 / sqrt(headdim). + causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). +Return: + out: (batch_size, seqlen, nheads, headdim). +""" +``` + +To see how these functions are used in a multi-head attention layer (which +includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py). + +## Upgrading from FlashAttention (1.x) to FlashAttention-2 + +These functions have been renamed: +- `flash_attn_unpadded_func` -> `flash_attn_varlen_func` +- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func` +- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func` + +If the inputs have the same sequence lengths in the same batch, it is simpler +and faster to use these functions: +```python +flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False) +``` +```python +flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False) +``` + +## Performance + +We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory). + +We currently have benchmarks for these GPUs: +* [A100](#a100) +* [H100](#h100) + + + +### A100 + +We display FlashAttention speedup using these parameters: +* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads). +* Sequence length 512, 1k, 2k, 4k, 8k, 16k. +* Batch size set to 16k / seqlen. + +#### Speedup + +![FlashAttention speedup on A100 80GB SXM5 with FP16/BF16](assets/flash2_a100_fwd_bwd_benchmark.png) + +#### Memory + +![FlashAttention memory](assets/flashattn_memory.jpg) + +We show memory savings in this graph (note that memory footprint is the same no matter if you use dropout or masking). +Memory savings are proportional to sequence length -- since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length. +We see 10X memory savings at sequence length 2K, and 20X at 4K. +As a result, FlashAttention can scale to much longer sequence lengths. + +### H100 + +![FlashAttention speedup on H100 SXM5 with FP16/BF16](assets/flash2_h100_fwd_bwd_benchmark.png) + +## Full model code and training script + +We have released the full GPT model +[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py). +We also provide optimized implementations of other layers (e.g., MLP, LayerNorm, +cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x +compared to the baseline implementation from Huggingface, reaching up to 225 +TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need +any activation checkpointing). + +We also include a training +[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to +train GPT2 on Openwebtext and GPT3 on The Pile. + +## Triton implementation of FlashAttention + +Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton: +https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py + +As Triton is a higher-level language than CUDA, it might be easier to understand +and experiment with. The notations in the Triton implementation are also closer +to what's used in our paper. + +We also have an experimental implementation in Triton that support attention +bias (e.g. ALiBi): +https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py + + +## Tests +We test that FlashAttention produces the same output and gradient as a reference +implementation, up to some numerical tolerance. In particular, we check that the +maximum numerical error of FlashAttention is at most twice the numerical error +of a baseline implementation in Pytorch (for different head dimensions, input +dtype, sequence length, causal / non-causal). + +To run the tests: +```sh +pytest -q -s tests/test_flash_attn.py +``` +## When you encounter issues + +This new release of FlashAttention-2 has been tested on several GPT-style +models, mostly on A100 GPUs. + +If you encounter bugs, please open a GitHub Issue! + +## Citation +If you use this codebase, or otherwise found our work valuable, please cite: +``` +@inproceedings{dao2022flashattention, + title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, + author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, + booktitle={Advances in Neural Information Processing Systems}, + year={2022} +} +@article{dao2023flashattention2, + title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning, + author={Dao, Tri}, + year={2023} +} +``` diff --git a/csrc/mp.py b/csrc/mp.py new file mode 100644 index 00000000000..c379c9c4ee5 --- /dev/null +++ b/csrc/mp.py @@ -0,0 +1,336 @@ +# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py +import sys +import warnings +import os +import re +import ast +from pathlib import Path +from packaging.version import parse, Version +import platform + +from setuptools import setup, find_packages +import subprocess + +import urllib.request +import urllib.error +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import paddle +from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension +from paddle.utils.cpp_extension.extension_utils import find_cuda_home + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) +CUDA_HOME = find_cuda_home() +PACKAGE_NAME = "flash_attn" + +BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}" + +# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels +# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI +FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" +# For CI, we want the option to not add "--threads 4" to nvcc, since the runner can OOM +FORCE_SINGLE_THREAD = os.getenv("FLASH_ATTENTION_FORCE_SINGLE_THREAD", "FALSE") == "TRUE" + + +def get_platform(): + """ + Returns the platform name as used in wheel filenames. + """ + if sys.platform.startswith('linux'): + return 'linux_x86_64' + elif sys.platform == 'darwin': + mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2]) + return f'macosx_{mac_version}_x86_64' + elif sys.platform == 'win32': + return 'win_amd64' + else: + raise ValueError('Unsupported platform: {}'.format(sys.platform)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_cuda_paddle_binary_vs_bare_metal(cuda_dir): + raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) + paddle_binary_version = parse(paddle.version.cuda()) + + print("\nCompiling cuda extensions with") + print(raw_output + "from " + cuda_dir + "/bin\n") + + if (bare_metal_version != paddle_binary_version): + raise RuntimeError( + "Cuda extensions are being compiled with a version of Cuda that does " + "not match the version used to compile Pypaddle binaries. " + "Pypaddle binaries were compiled with Cuda {}.\n".format(paddle.version.cuda) + + "In some cases, a minor-version mismatch will not cause later errors: " + "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " + "You can try commenting out this check (at your own risk)." + ) + + +def raise_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + raise RuntimeError( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pypaddle/pypaddle, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + if not FORCE_SINGLE_THREAD: + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + +def _is_cuda_available(): + """ + Check whether CUDA is available. + """ + try: + assert len(paddle.static.cuda_places()) > 0 + return True + except Exception as e: + logging.warning( + "You are using GPU version PaddlePaddle, but there is no GPU " + "detected on your machine. Maybe CUDA devices is not set properly." + f"\n Original Error is {e}" + ) + return False + +if paddle.is_compiled_with_cuda() and _is_cuda_available(): + # https://github.com/NVIDIA/apex/issues/486 + # Extension builds after https://github.com/pypaddle/pypaddle/pull/23408 attempt to query paddle.cuda.get_device_capability(), + # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). + print( + "\nWarning: Torch did not find available GPUs on this system.\n", + "If your intention is to cross-compile, this is not an error.\n" + "By default, FlashAttention will cross-compile for Ampere (compute capability 8.0, 8.6, " + "8.9), and, if the CUDA version is >= 11.8, Hopper (compute capability 9.0).\n" + "If you wish to cross-compile for a single specific architecture,\n" + 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', + ) + if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.8"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;9.0" + elif bare_metal_version >= Version("11.4"): + os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6" + else: + os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6" + +cmdclass = {} +ext_modules = [] + +# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp +# files included in the source distribution, in case the user compiles from source. +subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) + +if not SKIP_CUDA_BUILD: + print("\n\npaddle.__version__ = {}\n\n".format(paddle.__version__)) + TORCH_MAJOR = int(paddle.__version__.split(".")[0]) + TORCH_MINOR = int(paddle.__version__.split(".")[1]) + + # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h + # See https://github.com/pypaddle/pypaddle/pull/70650 + generator_flag = [] + paddle_dir = paddle.__path__[0] + if os.path.exists(os.path.join(paddle_dir, "include", "ATen", "CUDAGeneratorImpl.h")): + generator_flag = ["-DOLD_GENERATOR_PATH"] + + raise_if_cuda_home_none("flash_attn") + # Check, if CUDA11 is installed for compute capability 8.0 + cc_flag = [] + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version < Version("11.4"): + raise RuntimeError("FlashAttention is only supported on CUDA 11.4 and above") + # cc_flag.append("-gencode") + # cc_flag.append("arch=compute_75,code=sm_75") + cc_flag.append("-gencode") + cc_flag.append("arch=compute_80,ggcode=sm_80") + if bare_metal_version >= Version("11.8"): + cc_flag.append("-gencode") + cc_flag.append("arch=compute_90,code=sm_90") + + # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as + # paddle._C._GLIBCXX_USE_CXX11_ABI + # https://github.com/pypaddle/pypaddle/blob/8472c24e3b5b60150096486616d98b7bea01500b/paddle/utils/cpp_extension.py#L920 + if FORCE_CXX11_ABI: + paddle._C._GLIBCXX_USE_CXX11_ABI = True + ext_modules.append( + CUDAExtension( + sources=[ + "csrc/flash_attn/flash_api.cpp", + "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", + "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu", + "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"] + generator_flag, + "nvcc": append_nvcc_threads( + [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v", + # "--ptxas-options=-O2", + "-lineinfo" + ] + + generator_flag + + cc_flag + ), + }, + include_dirs=[ + Path(this_dir) / 'csrc' / 'flash_attn', + Path(this_dir) / 'csrc' / 'flash_attn' / 'src', + Path(this_dir) / 'csrc' / 'cutlass' / 'include', + ], + ) + ) + + +def get_package_version(): + with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f: + version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) + public_version = ast.literal_eval(version_match.group(1)) + local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION") + if local_version: + return f"{public_version}+{local_version}" + else: + return str(public_version) + + +class CachedWheelsCommand(_bdist_wheel): + """ + The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot + find an existing wheel (which is currently the case for all flash attention installs). We use + the environment parameters to detect whether there is already a pre-built version of a compatible + wheel available and short-circuits the standard full build pipeline. + """ + def run(self): + if FORCE_BUILD: + return super().run() + + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build paddle, not the one currently installed + # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) + paddle_cuda_version = parse(paddle.version.cuda) + paddle_version_raw = parse(paddle.__version__) + python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" + platform_name = get_platform() + flash_version = get_package_version() + # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" + cuda_version = f"{paddle_cuda_version.major}{paddle_cuda_version.minor}" + paddle_version = f"{paddle_version_raw.major}.{paddle_version_raw.minor}" + cxx11_abi = str(paddle._C._GLIBCXX_USE_CXX11_ABI).upper() + + # Determine wheel URL based on CUDA version, paddle version, python version and OS + wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}paddle{paddle_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl' + wheel_url = BASE_WHEEL_URL.format( + tag_name=f"v{flash_version}", + wheel_name=wheel_filename + ) + print("Guessing wheel URL: ", wheel_url) + + try: + urllib.request.urlretrieve(wheel_url, wheel_filename) + + # Make the archive + # Lifted from the root wheel processing command + # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 + if not os.path.exists(self.dist_dir): + os.makedirs(self.dist_dir) + + impl_tag, abi_tag, plat_tag = self.get_tag() + archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + + wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") + print("Raw wheel path", wheel_path) + os.rename(wheel_filename, wheel_path) + except urllib.error.HTTPError: + print("Precompiled wheel not found. Building from source...") + # If the wheel could not be downloaded, build from source + super().run() + + +setup( + name=PACKAGE_NAME, + version=get_package_version(), + packages=find_packages( + exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",) + ), + author="Tri Dao", + author_email="trid@cs.stanford.edu", + description="Flash Attention: Fast and Memory-Efficient Exact Attention", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/Dao-AILab/flash-attention", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + ], + #ext_modules=ext_modules, + cmdclass={ + 'bdist_wheel': CachedWheelsCommand, + "build_ext": BuildExtension + } if ext_modules else { + 'bdist_wheel': CachedWheelsCommand, + }, + python_requires=">=3.7", + install_requires=[ + "paddle", + "einops", + "packaging", + "ninja", + ], +) From 78080ddb5dfdec3a0154698ea81ccc3a45826c39 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 13:02:53 +0800 Subject: [PATCH 04/37] update --- csrc/mp.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/mp.py b/csrc/mp.py index c379c9c4ee5..ab8435e8e71 100644 --- a/csrc/mp.py +++ b/csrc/mp.py @@ -239,7 +239,7 @@ def _is_cuda_available(): def get_package_version(): - with open(Path(this_dir) / "flash_attn" / "__init__.py", "r") as f: + with open(Path(this_dir) / "../flash_attn" / "__init__.py", "r") as f: version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) public_version = ast.literal_eval(version_match.group(1)) local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION") From 0d6766ef6c134f8a550db884409238c949c32ef8 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 14:19:48 +0800 Subject: [PATCH 05/37] update --- csrc/setup.py | 255 +++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 244 insertions(+), 11 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index 92a9a3f16ff..33bc82050c4 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -1,15 +1,248 @@ +# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py +import sys +import warnings +import os +import re +import ast +from pathlib import Path +from packaging.version import parse, Version +import platform + from setuptools import setup, find_packages -from setuptools import setup, find_namespace_packages +import subprocess + +import urllib.request +import urllib.error +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import paddle +from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension +from paddle.utils.cpp_extension.extension_utils import find_cuda_home + +with open("README.md", "r", encoding="utf-8") as fh: + long_description = fh.read() + + +# ninja build does not work unless include_dirs are abs path +this_dir = os.path.dirname(os.path.abspath(__file__)) +CUDA_HOME = find_cuda_home() +PACKAGE_NAME = "flash_attn" + +BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}" + +# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels +# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation +FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI +FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" +# For CI, we want the option to not add "--threads 4" to nvcc, since the runner can OOM +FORCE_SINGLE_THREAD = os.getenv("FLASH_ATTENTION_FORCE_SINGLE_THREAD", "FALSE") == "TRUE" + + +def get_platform(): + """ + Returns the platform name as used in wheel filenames. + """ + if sys.platform.startswith('linux'): + return 'linux_x86_64' + elif sys.platform == 'darwin': + mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2]) + return f'macosx_{mac_version}_x86_64' + elif sys.platform == 'win32': + return 'win_amd64' + else: + raise ValueError('Unsupported platform: {}'.format(sys.platform)) + + +def get_cuda_bare_metal_version(cuda_dir): + raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + output = raw_output.split() + release_idx = output.index("release") + 1 + bare_metal_version = parse(output[release_idx].split(",")[0]) + + return raw_output, bare_metal_version + + +def check_cuda_paddle_binary_vs_bare_metal(cuda_dir): + raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) + paddle_binary_version = parse(paddle.version.cuda()) + + print("\nCompiling cuda extensions with") + print(raw_output + "from " + cuda_dir + "/bin\n") + + if (bare_metal_version != paddle_binary_version): + raise RuntimeError( + "Cuda extensions are being compiled with a version of Cuda that does " + "not match the version used to compile Pypaddle binaries. " + "Pypaddle binaries were compiled with Cuda {}.\n".format(paddle.version.cuda) + + "In some cases, a minor-version mismatch will not cause later errors: " + "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " + "You can try commenting out this check (at your own risk)." + ) + + +def raise_if_cuda_home_none(global_option: str) -> None: + if CUDA_HOME is not None: + return + raise RuntimeError( + f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " + "If you're installing within a container from https://hub.docker.com/r/pypaddle/pypaddle, " + "only images whose names contain 'devel' will provide nvcc." + ) + + +def append_nvcc_threads(nvcc_extra_args): + if not FORCE_SINGLE_THREAD: + return nvcc_extra_args + ["--threads", "4"] + return nvcc_extra_args + +def _is_cuda_available(): + """ + Check whether CUDA is available. + """ + try: + assert len(paddle.static.cuda_places()) > 0 + return True + except Exception as e: + logging.warning( + "You are using GPU version PaddlePaddle, but there is no GPU " + "detected on your machine. Maybe CUDA devices is not set properly." + f"\n Original Error is {e}" + ) + return False + +if paddle.is_compiled_with_cuda() and _is_cuda_available(): + # https://github.com/NVIDIA/apex/issues/486 + # Extension builds after https://github.com/pypaddle/pypaddle/pull/23408 attempt to query paddle.cuda.get_device_capability(), + # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). + print( + "\nWarning: Torch did not find available GPUs on this system.\n", + "If your intention is to cross-compile, this is not an error.\n" + "By default, FlashAttention will cross-compile for Ampere (compute capability 8.0, 8.6, " + "8.9), and, if the CUDA version is >= 11.8, Hopper (compute capability 9.0).\n" + "If you wish to cross-compile for a single specific architecture,\n" + 'export PADDLE_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', + ) + if os.environ.get("PADDLE_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: + _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) + if bare_metal_version >= Version("11.8"): + os.environ["PADDLE_CUDA_ARCH_LIST"] = "8.0;8.6;9.0" + elif bare_metal_version >= Version("11.4"): + os.environ["PADDLE_CUDA_ARCH_LIST"] = "8.0;8.6" + else: + os.environ["PADDLE_CUDA_ARCH_LIST"] = "8.0;8.6" + +cmdclass = {} +ext_modules = [] + +def get_package_version(): + with open(Path(this_dir) / "../flash_attn" / "__init__.py", "r") as f: + version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) + public_version = ast.literal_eval(version_match.group(1)) + local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION") + if local_version: + return f"{public_version}+{local_version}" + else: + return str(public_version) + +def get_data_files(): + data_files = [] + + # Assuming 'libflashattn.so' is located in the same directory as setup.py + source_lib_path = 'libflashattn.so' + + # Specify the destination directory within the package + destination_lib_path = os.path.join(PACKAGE_NAME, 'libflashattn.so') + + data_files.append((os.path.join(PACKAGE_NAME, 'libflashattn.so'), [source_lib_path])) + return data_files + + +class CachedWheelsCommand(_bdist_wheel): + """ + The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot + find an existing wheel (which is currently the case for all flash attention installs). We use + the environment parameters to detect whether there is already a pre-built version of a compatible + wheel available and short-circuits the standard full build pipeline. + """ + def run(self): + if FORCE_BUILD: + return super().run() + + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build paddle, not the one currently installed + # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) + paddle_cuda_version = parse(paddle.version.cuda) + paddle_version_raw = parse(paddle.__version__) + python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" + platform_name = get_platform() + flash_version = get_package_version() + # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" + cuda_version = f"{paddle_cuda_version.major}{paddle_cuda_version.minor}" + paddle_version = f"{paddle_version_raw.major}.{paddle_version_raw.minor}" + cxx11_abi = str(paddle._C._GLIBCXX_USE_CXX11_ABI).upper() + + # Determine wheel URL based on CUDA version, paddle version, python version and OS + wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}paddle{paddle_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl' + wheel_url = BASE_WHEEL_URL.format( + tag_name=f"v{flash_version}", + wheel_name=wheel_filename + ) + print("Guessing wheel URL: ", wheel_url) + + try: + urllib.request.urlretrieve(wheel_url, wheel_filename) + + # Make the archive + # Lifted from the root wheel processing command + # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 + if not os.path.exists(self.dist_dir): + os.makedirs(self.dist_dir) + + impl_tag, abi_tag, plat_tag = self.get_tag() + archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + + wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") + print("Raw wheel path", wheel_path) + os.rename(wheel_filename, wheel_path) + except urllib.error.HTTPError: + print("Precompiled wheel not found. Building from source...") + # If the wheel could not be downloaded, build from source + super().run() + setup( - packages=find_packages(where="src"), - package_dir={"": "src"}, - package_data={"": ["*.so"]}, - exclude_package_data={"flash_attn_with_bias_and_mask": ["*"]}, - include_package_data=True, - #packages=find_namespace_packages(where="src"), - #package_dir={"": "src"}, - #package_data={ - # "": ["*.so"], - #} + name=PACKAGE_NAME, + version=get_package_version(), + packages=find_packages( + #exclude=("build") + #, "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",) + ), + data_files=get_data_files(), + package_data={PACKAGE_NAME: ['build/libflashattn.so']}, + author_email="Paddle-better@baidu.com", + description="Flash Attention: Fast and Memory-Efficient Exact Attention", + long_description=long_description, + long_description_content_type="text/markdown", + url="https://github.com/PaddlePaddle/flash-attention", + classifiers=[ + "Programming Language :: Python :: 3", + "License :: OSI Approved :: BSD License", + "Operating System :: Unix", + ], + ext_modules=ext_modules, + cmdclass={ + 'bdist_wheel': CachedWheelsCommand, + "build_ext": BuildExtension + } if ext_modules else { + 'bdist_wheel': CachedWheelsCommand, + }, + python_requires=">=3.7", + install_requires=[ + "paddle", + "einops", + "packaging", + "ninja", + ], ) From 17b89c9047017af7cb127bf5c1eec349e9045b5a Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 16:18:14 +0800 Subject: [PATCH 06/37] updat --- csrc/CMakeLists.txt | 2 +- csrc/setup.py | 287 ++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 281 insertions(+), 8 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 639d8b1eb4f..3b84b53889c 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -138,7 +138,7 @@ INSTALL(TARGETS flashattn INSTALL(FILES capi/flash_attn.h DESTINATION "include") add_custom_target(run_my_executable - COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/tp.py sdist bdist_wheel + COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py sdist bdist_wheel WORKING_DIRECTORY ${CMAKE_BINARY_DIR} DEPENDS flashattn COMMENT "Running my_executable" diff --git a/csrc/setup.py b/csrc/setup.py index 33bc82050c4..47d6a955503 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -8,8 +8,135 @@ from packaging.version import parse, Version import platform +from setuptools import Command, Extension, setup +from setuptools.command.develop import develop as DevelopCommandBase +from setuptools.command.egg_info import egg_info +from setuptools.command.install import install as InstallCommandBase +from setuptools.command.install_lib import install_lib +from setuptools.dist import Distribution from setuptools import setup, find_packages import subprocess +python_version = platform.python_version() +version_detail = sys.version_info +version = version_detail[0] + version_detail[1] / 10 +env_version = os.getenv("PY_VERSION") + +if version < 3.7: + raise RuntimeError( + f"Paddle only supports Python version >= 3.7 now," + f"you are using Python {python_version}" + ) +elif env_version is None: + print(f"export PY_VERSION = { python_version }") + os.environ["PY_VERSION"] = python_version + +elif env_version != version: + warnings.warn( + f"You set PY_VERSION={env_version}, but" + f"your current python environment is {version}" + f"we will use your current python version to execute" + ) + os.environ["PY_VERSION"] = python_version + + +global env_dict # noqa: F811 +env_dict={ + 'PADDLE_SOURCE_DIR':'@PADDLE_SOURCE_DIR@', + 'PADDLE_VERSION':'@PADDLE_VERSION@', + 'PADDLE_BINARY_DIR':'@PADDLE_BINARY_DIR@', + 'TAG_VERSION_REGEX':'@TAG_VERSION_REGEX@', + 'WITH_GPU':'@WITH_GPU@', + 'CUDNN_MAJOR_VERSION':'@CUDNN_MAJOR_VERSION@', + 'CUDNN_MINOR_VERSION':'@CUDNN_MINOR_VERSION@', + 'CUDNN_PATCHLEVEL_VERSION':'@CUDNN_PATCHLEVEL_VERSION@', + 'CUDA_VERSION':'@CUDA_VERSION@', + 'WITH_PSLI':'@WITH_PSLI@', + 'FLUID_CORE_NAME':'@FLUID_CORE_NAME@', + 'PHI_LIB':'@PHI_LIB@', + 'PHI_NAME':'@PHI_NAME@', + 'WITH_SHARED_PHI':'@WITH_SHARED_PHI@', + 'IR_LIB':'@IR_LIB@', + 'IR_NAME':'@IR_NAME@', + 'WITH_SHARED_IR':'@WITH_SHARED_IR@', + 'COMMON_LIB':'@COMMON_LIB@', + 'COMMON_NAME':'@COMMON_NAME@', + 'WARPCTC_LIBRARIES':'@WARPCTC_LIBRARIES@', + 'WARPRNNT_LIBRARIES':'@WARPRNNT_LIBRARIES@', + 'FLASHATTN_LIBRARIES':'@FLASHATTN_LIBRARIES@', + 'LAPACK_LIB':'@LAPACK_LIB@', + 'GFORTRAN_LIB':'@GFORTRAN_LIB@', + 'GNU_RT_LIB_1':'@GNU_RT_LIB_1@', + 'WITH_CUDNN_DSO':'@WITH_CUDNN_DSO@', + 'CUDNN_LIBRARY':'@CUDNN_LIBRARY@', + 'GNU_RT_LIB_2':'@GNU_RT_LIB_2@', + 'WITH_MKL':'@WITH_MKL@', + 'MKLML_SHARED_LIB':'@MKLML_SHARED_LIB@', + 'MKLML_SHARED_IOMP_LIB':'@MKLML_SHARED_IOMP_LIB@', + 'OPENBLAS_SHARED_LIB':'@OPENBLAS_SHARED_LIB@', + 'OPENBLAS_LIB':'@OPENBLAS_LIB@', + 'BLAS_LIB':'@BLAS_LIB@', + 'WITH_LITE':'@WITH_LITE@', + 'LITE_SHARED_LIB':'@LITE_SHARED_LIB@', + 'LITE_WITH_NNADAPTER':'@LITE_WITH_NNADAPTER@', + 'LITE_NNADAPTER_LIB':'@LITE_NNADAPTER_LIB@', + 'NNADAPTER_WITH_HUAWEI_ASCEND_NPU':'@NNADAPTER_WITH_HUAWEI_ASCEND_NPU@', + 'LITE_NNADAPTER_NPU_LIB':'@LITE_NNADAPTER_NPU_LIB@', + 'WITH_CINN':'@WITH_CINN@', + 'CINN_LIB_LOCATION':'@CINN_LIB_LOCATION@', + 'CINN_LIB_NAME':'@CINN_LIB_NAME@', + 'CINN_INCLUDE_DIR':'@CINN_INCLUDE_DIR@', + 'CMAKE_BUILD_TYPE':'@CMAKE_BUILD_TYPE@', + 'PSLIB_LIB':'@PSLIB_LIB@', + 'JVM_LIB':'@JVM_LIB@', + 'PSLIB_VERSION_PY':'@PSLIB_VERSION_PY@', + 'WITH_MKLDNN':'@WITH_MKLDNN@', + 'MKLDNN_SHARED_LIB':'@MKLDNN_SHARED_LIB@', + 'MKLDNN_INSTALL_DIR':'@MKLDNN_INSTALL_DIR@', + 'WITH_ONNXRUNTIME':'@WITH_ONNXRUNTIME@', + 'ONNXRUNTIME_SHARED_LIB':'@ONNXRUNTIME_SHARED_LIB@', + 'PADDLE2ONNX_LIB':'@PADDLE2ONNX_LIB@', + 'PADDLE2ONNX_LIB_NAME':'@PADDLE2ONNX_LIB_NAME@', + 'ONNXRUNTIME_LIB_NAME':'@ONNXRUNTIME_LIB_NAME@', + 'WITH_XPU':'@WITH_XPU@', + 'XPU_API_LIB':'@XPU_API_LIB@', + 'XPU_API_LIB_NAME':'@XPU_API_LIB_NAME@', + 'XPU_RT_LIB':'@XPU_RT_LIB@', + 'XPU_RT_LIB_NAME':'@XPU_RT_LIB_NAME@', + 'WITH_XPU_BKCL':'@WITH_XPU_BKCL@', + 'XPU_BKCL_LIB':'@XPU_BKCL_LIB@', + 'XPU_BKCL_LIB_NAME':'@XPU_BKCL_LIB_NAME@', + 'WITH_XPU_XFT':'@WITH_XPU_XFT@', + 'XPU_XFT_LIB':'@XPU_XFT_LIB@', + 'XPU_XFT_LIB_NAME':'@XPU_XFT_LIB_NAME@', + 'THIRD_PARTY_PATH':'@THIRD_PARTY_PATH@', + 'SETUP_LOG_FILE':'@SETUP_LOG_FILE@', + 'WITH_STRIP':'@WITH_STRIP@', + 'PACKAGE_NAME':'@PACKAGE_NAME@', + 'PADDLE_VERSION':'@PADDLE_VERSION@', + 'APPLE':'@APPLE@', + 'externalError_INCLUDE_DIR':'@externalError_INCLUDE_DIR@', + 'WITH_ROCM':'@WITH_ROCM@', + 'ORIGIN':'@ORIGIN@', + 'WIN32':'@WIN32@', + 'JIT_RELEASE_WHL':'@JIT_RELEASE_WHL@', + 'WITH_PSLIB':'@WITH_PSLIB@', + 'PYBIND_INCLUDE_DIR':'@PYBIND_INCLUDE_DIR@', + 'WITH_PYTHON':'@WITH_PYTHON@', + 'WITH_CINN':'@WITH_CINN@', + 'CINN_SOURCE_DIR':'@CINN_SOURCE_DIR@', + 'WITH_CPP_DIST':'@WITH_CPP_DIST@', + 'PADDLE_INSTALL_DIR':'@PADDLE_INSTALL_DIR@', + 'PADDLE_LIB_TEST_DIR':'@PADDLE_LIB_TEST_DIR@' +} + +global paddle_binary_dir, paddle_source_dir + +paddle_binary_dir = env_dict.get("PADDLE_BINARY_DIR") +paddle_source_dir = env_dict.get("PADDLE_SOURCE_DIR") + +# preparing parameters for setup() +paddle_version = env_dict.get("PADDLE_VERSION") +package_name = env_dict.get("PACKAGE_NAME") import urllib.request import urllib.error @@ -19,7 +146,7 @@ from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension from paddle.utils.cpp_extension.extension_utils import find_cuda_home -with open("README.md", "r", encoding="utf-8") as fh: +with open("../README.md", "r", encoding="utf-8") as fh: long_description = fh.read() @@ -153,7 +280,7 @@ def get_data_files(): source_lib_path = 'libflashattn.so' # Specify the destination directory within the package - destination_lib_path = os.path.join(PACKAGE_NAME, 'libflashattn.so') + destination_lib_path = os.path.join(PACKAGE_NAME, 'build/libflashattn.so') data_files.append((os.path.join(PACKAGE_NAME, 'libflashattn.so'), [source_lib_path])) return data_files @@ -211,6 +338,145 @@ def run(self): # If the wheel could not be downloaded, build from source super().run() +class InstallHeaders(Command): + """Override how headers are copied.""" + + description = 'install C/C++ header files' + + user_options = [ + ('install-dir=', 'd', 'directory to install header files to'), + ('force', 'f', 'force installation (overwrite existing files)'), + ] + + boolean_options = ['force'] + + def initialize_options(self): + self.install_dir = None + self.force = 0 + self.outfiles = [] + + def finalize_options(self): + self.set_undefined_options( + 'install', ('install_headers', 'install_dir'), ('force', 'force') + ) + + def run(self): + hdrs = self.distribution.headers + if not hdrs: + return + self.mkpath(self.install_dir) + for header in hdrs: + install_dir = get_header_install_dir(header) + install_dir = os.path.join( + self.install_dir, os.path.dirname(install_dir) + ) + if not os.path.exists(install_dir): + self.mkpath(install_dir) + (out, _) = self.copy_file(header, install_dir) + self.outfiles.append(out) + # (out, _) = self.mkdir_and_copy_file(header) + # self.outfiles.append(out) + + def get_inputs(self): + return self.distribution.headers or [] + + def get_outputs(self): + return self.outfiles + + +class InstallCommand(InstallCommandBase): + def finalize_options(self): + ret = InstallCommandBase.finalize_options(self) + self.install_lib = self.install_platlib + + self.install_headers = os.path.join( + self.install_platlib, 'paddle', 'include' + ) + return ret + + +class DevelopCommand(DevelopCommandBase): + def run(self): + # copy proto and .so to python_source_dir + fluid_proto_binary_path = ( + paddle_binary_dir + '/python/paddle/base/proto/' + ) + fluid_proto_source_path = ( + paddle_source_dir + '/python/paddle/base/proto/' + ) + distributed_proto_binary_path = ( + paddle_binary_dir + '/python/paddle/distributed/fleet/proto/' + ) + distributed_proto_source_path = ( + paddle_source_dir + '/python/paddle/distributed/fleet/proto/' + ) + os.system(f"rm -rf {fluid_proto_source_path}") + shutil.copytree(fluid_proto_binary_path, fluid_proto_source_path) + os.system(f"rm -rf {distributed_proto_source_path}") + shutil.copytree( + distributed_proto_binary_path, distributed_proto_source_path + ) + shutil.copy( + paddle_binary_dir + '/python/paddle/base/libpaddle.so', + paddle_source_dir + '/python/paddle/base/', + ) + dynamic_library_binary_path = paddle_binary_dir + '/python/paddle/libs/' + dynamic_library_source_path = paddle_source_dir + '/python/paddle/libs/' + for lib_so in os.listdir(dynamic_library_binary_path): + shutil.copy( + dynamic_library_binary_path + lib_so, + dynamic_library_source_path, + ) + # write version.py and cuda_env_config_py to python_source_dir + write_version_py( + filename=f'{paddle_source_dir}/python/paddle/version/__init__.py' + ) + write_cuda_env_config_py( + filename=f'{paddle_source_dir}/python/paddle/cuda_env.py' + ) + write_parameter_server_version_py( + filename='{}/python/paddle/incubate/distributed/fleet/parameter_server/version.py'.format( + paddle_source_dir + ) + ) + DevelopCommandBase.run(self) + + +class EggInfo(egg_info): + """Copy license file into `.dist-info` folder.""" + + def run(self): + # don't duplicate license into `.dist-info` when building a distribution + if not self.distribution.have_run.get('install', True): + self.mkpath(self.egg_info) + #self.copy_file( + # env_dict.get("PADDLE_SOURCE_DIR") + "/LICENSE", self.egg_info + #) + + egg_info.run(self) + + +# class Installlib is rewritten to add header files to .egg/paddle +class InstallLib(install_lib): + def run(self): + self.build() + outfiles = self.install() + hrds = self.distribution.headers + if not hrds: + return + for header in hrds: + install_dir = get_header_install_dir(header) + install_dir = os.path.join( + self.install_dir, 'paddle/include', os.path.dirname(install_dir) + ) + if not os.path.exists(install_dir): + self.mkpath(install_dir) + self.copy_file(header, install_dir) + if outfiles is not None: + # always compile, in case we have any extension stubs to deal with + self.byte_compile(outfiles) + + setup( name=PACKAGE_NAME, @@ -227,17 +493,24 @@ def run(self): long_description_content_type="text/markdown", url="https://github.com/PaddlePaddle/flash-attention", classifiers=[ - "Programming Language :: Python :: 3", + "Programming Language :: Python :: 37", "License :: OSI Approved :: BSD License", "Operating System :: Unix", ], ext_modules=ext_modules, cmdclass={ - 'bdist_wheel': CachedWheelsCommand, - "build_ext": BuildExtension - } if ext_modules else { - 'bdist_wheel': CachedWheelsCommand, + 'install_headers': InstallHeaders, + 'install': InstallCommand, + 'egg_info': EggInfo, + 'install_lib': InstallLib, + 'develop': DevelopCommand, }, + #cmdclass={ + # "bdist_wheel": CachedWheelsCommand, + # "build_ext": BuildExtension + #} if ext_modules else { + # "bdist_wheel": CachedWheelsCommand, + #}, python_requires=">=3.7", install_requires=[ "paddle", From 1355060bb29ad6c703f003d8022fb786cc158a52 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 17:10:56 +0800 Subject: [PATCH 07/37] update --- csrc/setup.py | 198 +++++++++----------------------------------------- 1 file changed, 35 insertions(+), 163 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index 47d6a955503..f2d7357cecf 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -7,7 +7,7 @@ from pathlib import Path from packaging.version import parse, Version import platform - +import shutil from setuptools import Command, Extension, setup from setuptools.command.develop import develop as DevelopCommandBase from setuptools.command.egg_info import egg_info @@ -15,9 +15,17 @@ from setuptools.command.install_lib import install_lib from setuptools.dist import Distribution from setuptools import setup, find_packages +import urllib.request +import urllib.error +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel + +import paddle +from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension +from paddle.utils.cpp_extension.extension_utils import find_cuda_home import subprocess -python_version = platform.python_version() + version_detail = sys.version_info +python_version = platform.python_version() version = version_detail[0] + version_detail[1] / 10 env_version = os.getenv("PY_VERSION") @@ -38,113 +46,16 @@ ) os.environ["PY_VERSION"] = python_version +paddle_include_path = paddle.sysconfig.get_include() +paddle_lib_path = paddle.sysconfig.get_lib() -global env_dict # noqa: F811 -env_dict={ - 'PADDLE_SOURCE_DIR':'@PADDLE_SOURCE_DIR@', - 'PADDLE_VERSION':'@PADDLE_VERSION@', - 'PADDLE_BINARY_DIR':'@PADDLE_BINARY_DIR@', - 'TAG_VERSION_REGEX':'@TAG_VERSION_REGEX@', - 'WITH_GPU':'@WITH_GPU@', - 'CUDNN_MAJOR_VERSION':'@CUDNN_MAJOR_VERSION@', - 'CUDNN_MINOR_VERSION':'@CUDNN_MINOR_VERSION@', - 'CUDNN_PATCHLEVEL_VERSION':'@CUDNN_PATCHLEVEL_VERSION@', - 'CUDA_VERSION':'@CUDA_VERSION@', - 'WITH_PSLI':'@WITH_PSLI@', - 'FLUID_CORE_NAME':'@FLUID_CORE_NAME@', - 'PHI_LIB':'@PHI_LIB@', - 'PHI_NAME':'@PHI_NAME@', - 'WITH_SHARED_PHI':'@WITH_SHARED_PHI@', - 'IR_LIB':'@IR_LIB@', - 'IR_NAME':'@IR_NAME@', - 'WITH_SHARED_IR':'@WITH_SHARED_IR@', - 'COMMON_LIB':'@COMMON_LIB@', - 'COMMON_NAME':'@COMMON_NAME@', - 'WARPCTC_LIBRARIES':'@WARPCTC_LIBRARIES@', - 'WARPRNNT_LIBRARIES':'@WARPRNNT_LIBRARIES@', - 'FLASHATTN_LIBRARIES':'@FLASHATTN_LIBRARIES@', - 'LAPACK_LIB':'@LAPACK_LIB@', - 'GFORTRAN_LIB':'@GFORTRAN_LIB@', - 'GNU_RT_LIB_1':'@GNU_RT_LIB_1@', - 'WITH_CUDNN_DSO':'@WITH_CUDNN_DSO@', - 'CUDNN_LIBRARY':'@CUDNN_LIBRARY@', - 'GNU_RT_LIB_2':'@GNU_RT_LIB_2@', - 'WITH_MKL':'@WITH_MKL@', - 'MKLML_SHARED_LIB':'@MKLML_SHARED_LIB@', - 'MKLML_SHARED_IOMP_LIB':'@MKLML_SHARED_IOMP_LIB@', - 'OPENBLAS_SHARED_LIB':'@OPENBLAS_SHARED_LIB@', - 'OPENBLAS_LIB':'@OPENBLAS_LIB@', - 'BLAS_LIB':'@BLAS_LIB@', - 'WITH_LITE':'@WITH_LITE@', - 'LITE_SHARED_LIB':'@LITE_SHARED_LIB@', - 'LITE_WITH_NNADAPTER':'@LITE_WITH_NNADAPTER@', - 'LITE_NNADAPTER_LIB':'@LITE_NNADAPTER_LIB@', - 'NNADAPTER_WITH_HUAWEI_ASCEND_NPU':'@NNADAPTER_WITH_HUAWEI_ASCEND_NPU@', - 'LITE_NNADAPTER_NPU_LIB':'@LITE_NNADAPTER_NPU_LIB@', - 'WITH_CINN':'@WITH_CINN@', - 'CINN_LIB_LOCATION':'@CINN_LIB_LOCATION@', - 'CINN_LIB_NAME':'@CINN_LIB_NAME@', - 'CINN_INCLUDE_DIR':'@CINN_INCLUDE_DIR@', - 'CMAKE_BUILD_TYPE':'@CMAKE_BUILD_TYPE@', - 'PSLIB_LIB':'@PSLIB_LIB@', - 'JVM_LIB':'@JVM_LIB@', - 'PSLIB_VERSION_PY':'@PSLIB_VERSION_PY@', - 'WITH_MKLDNN':'@WITH_MKLDNN@', - 'MKLDNN_SHARED_LIB':'@MKLDNN_SHARED_LIB@', - 'MKLDNN_INSTALL_DIR':'@MKLDNN_INSTALL_DIR@', - 'WITH_ONNXRUNTIME':'@WITH_ONNXRUNTIME@', - 'ONNXRUNTIME_SHARED_LIB':'@ONNXRUNTIME_SHARED_LIB@', - 'PADDLE2ONNX_LIB':'@PADDLE2ONNX_LIB@', - 'PADDLE2ONNX_LIB_NAME':'@PADDLE2ONNX_LIB_NAME@', - 'ONNXRUNTIME_LIB_NAME':'@ONNXRUNTIME_LIB_NAME@', - 'WITH_XPU':'@WITH_XPU@', - 'XPU_API_LIB':'@XPU_API_LIB@', - 'XPU_API_LIB_NAME':'@XPU_API_LIB_NAME@', - 'XPU_RT_LIB':'@XPU_RT_LIB@', - 'XPU_RT_LIB_NAME':'@XPU_RT_LIB_NAME@', - 'WITH_XPU_BKCL':'@WITH_XPU_BKCL@', - 'XPU_BKCL_LIB':'@XPU_BKCL_LIB@', - 'XPU_BKCL_LIB_NAME':'@XPU_BKCL_LIB_NAME@', - 'WITH_XPU_XFT':'@WITH_XPU_XFT@', - 'XPU_XFT_LIB':'@XPU_XFT_LIB@', - 'XPU_XFT_LIB_NAME':'@XPU_XFT_LIB_NAME@', - 'THIRD_PARTY_PATH':'@THIRD_PARTY_PATH@', - 'SETUP_LOG_FILE':'@SETUP_LOG_FILE@', - 'WITH_STRIP':'@WITH_STRIP@', - 'PACKAGE_NAME':'@PACKAGE_NAME@', - 'PADDLE_VERSION':'@PADDLE_VERSION@', - 'APPLE':'@APPLE@', - 'externalError_INCLUDE_DIR':'@externalError_INCLUDE_DIR@', - 'WITH_ROCM':'@WITH_ROCM@', - 'ORIGIN':'@ORIGIN@', - 'WIN32':'@WIN32@', - 'JIT_RELEASE_WHL':'@JIT_RELEASE_WHL@', - 'WITH_PSLIB':'@WITH_PSLIB@', - 'PYBIND_INCLUDE_DIR':'@PYBIND_INCLUDE_DIR@', - 'WITH_PYTHON':'@WITH_PYTHON@', - 'WITH_CINN':'@WITH_CINN@', - 'CINN_SOURCE_DIR':'@CINN_SOURCE_DIR@', - 'WITH_CPP_DIST':'@WITH_CPP_DIST@', - 'PADDLE_INSTALL_DIR':'@PADDLE_INSTALL_DIR@', - 'PADDLE_LIB_TEST_DIR':'@PADDLE_LIB_TEST_DIR@' -} - -global paddle_binary_dir, paddle_source_dir - -paddle_binary_dir = env_dict.get("PADDLE_BINARY_DIR") -paddle_source_dir = env_dict.get("PADDLE_SOURCE_DIR") +print("Paddle Include Path:", paddle_include_path) +print("Paddle Lib Path:", paddle_lib_path) # preparing parameters for setup() -paddle_version = env_dict.get("PADDLE_VERSION") -package_name = env_dict.get("PACKAGE_NAME") +paddle_version = paddle.version.full_version +cuda_version= paddle.version.cuda_version -import urllib.request -import urllib.error -from wheel.bdist_wheel import bdist_wheel as _bdist_wheel - -import paddle -from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension -from paddle.utils.cpp_extension.extension_utils import find_cuda_home with open("../README.md", "r", encoding="utf-8") as fh: long_description = fh.read() @@ -153,14 +64,11 @@ # ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) CUDA_HOME = find_cuda_home() -PACKAGE_NAME = "flash_attn" +PACKAGE_NAME = "paddle_flash_attn" -BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}" +BASE_WHEEL_URL = "https://github.com/PaddlePaddle/flash-attention/releases/download/{tag_name}/{wheel_name}" -# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels -# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" -SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" # For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" # For CI, we want the option to not add "--threads 4" to nvcc, since the runner can OOM @@ -218,12 +126,6 @@ def raise_if_cuda_home_none(global_option: str) -> None: "only images whose names contain 'devel' will provide nvcc." ) - -def append_nvcc_threads(nvcc_extra_args): - if not FORCE_SINGLE_THREAD: - return nvcc_extra_args + ["--threads", "4"] - return nvcc_extra_args - def _is_cuda_available(): """ Check whether CUDA is available. @@ -261,7 +163,6 @@ def _is_cuda_available(): os.environ["PADDLE_CUDA_ARCH_LIST"] = "8.0;8.6" cmdclass = {} -ext_modules = [] def get_package_version(): with open(Path(this_dir) / "../flash_attn" / "__init__.py", "r") as f: @@ -294,49 +195,32 @@ class CachedWheelsCommand(_bdist_wheel): wheel available and short-circuits the standard full build pipeline. """ def run(self): - if FORCE_BUILD: - return super().run() - + print("88888888888888888888888888888") + # if FORCE_BUILD: + # return super().run() + self.run_command('build_ext') + super().run() # Determine the version numbers that will be used to determine the correct wheel # We're using the CUDA version used to build paddle, not the one currently installed # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) - paddle_cuda_version = parse(paddle.version.cuda) + paddle_cuda_version = "234" #parse(paddle.version.cuda) paddle_version_raw = parse(paddle.__version__) python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_name = get_platform() flash_version = get_package_version() - # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" - cuda_version = f"{paddle_cuda_version.major}{paddle_cuda_version.minor}" - paddle_version = f"{paddle_version_raw.major}.{paddle_version_raw.minor}" - cxx11_abi = str(paddle._C._GLIBCXX_USE_CXX11_ABI).upper() + cxx11_abi ="" # str(paddle._C.-D_GLIBCXX_USE_CXX11_ABI).upper() # Determine wheel URL based on CUDA version, paddle version, python version and OS - wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}paddle{paddle_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl' - wheel_url = BASE_WHEEL_URL.format( - tag_name=f"v{flash_version}", - wheel_name=wheel_filename - ) - print("Guessing wheel URL: ", wheel_url) - - try: - urllib.request.urlretrieve(wheel_url, wheel_filename) - - # Make the archive - # Lifted from the root wheel processing command - # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 - if not os.path.exists(self.dist_dir): - os.makedirs(self.dist_dir) - - impl_tag, abi_tag, plat_tag = self.get_tag() - archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + wheel_filename = f'{PACKAGE_NAME}-{flash_version}-cu{cuda_version}-paddle{paddle_version}-{python_version}-{python_version}-{platform_name}.whl' + impl_tag, abi_tag, plat_tag = self.get_tag() + original_wheel_name = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" - wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") - print("Raw wheel path", wheel_path) - os.rename(wheel_filename, wheel_path) - except urllib.error.HTTPError: - print("Precompiled wheel not found. Building from source...") - # If the wheel could not be downloaded, build from source - super().run() + new_wheel_name = wheel_filename + print("self.asdfasdfsdfasdfasdfasdf", self.get_tag()) + shutil.move( + f"{self.dist_dir}/{original_wheel_name}.whl", + f"{self.dist_dir}/{new_wheel_name}" + ) class InstallHeaders(Command): """Override how headers are copied.""" @@ -497,20 +381,8 @@ def run(self): "License :: OSI Approved :: BSD License", "Operating System :: Unix", ], - ext_modules=ext_modules, cmdclass={ - 'install_headers': InstallHeaders, - 'install': InstallCommand, - 'egg_info': EggInfo, - 'install_lib': InstallLib, - 'develop': DevelopCommand, - }, - #cmdclass={ - # "bdist_wheel": CachedWheelsCommand, - # "build_ext": BuildExtension - #} if ext_modules else { - # "bdist_wheel": CachedWheelsCommand, - #}, + "bdist_wheel": CachedWheelsCommand,}, python_requires=">=3.7", install_requires=[ "paddle", From d536119e240ac62eb6b8b71c872115a49200c11e Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 18:53:36 +0800 Subject: [PATCH 08/37] update --- csrc/setup.py | 178 ++------------------------------------------------ 1 file changed, 7 insertions(+), 171 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index f2d7357cecf..dcfc4e6915b 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -141,27 +141,6 @@ def _is_cuda_available(): ) return False -if paddle.is_compiled_with_cuda() and _is_cuda_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pypaddle/pypaddle/pull/23408 attempt to query paddle.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, FlashAttention will cross-compile for Ampere (compute capability 8.0, 8.6, " - "8.9), and, if the CUDA version is >= 11.8, Hopper (compute capability 9.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export PADDLE_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("PADDLE_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["PADDLE_CUDA_ARCH_LIST"] = "8.0;8.6;9.0" - elif bare_metal_version >= Version("11.4"): - os.environ["PADDLE_CUDA_ARCH_LIST"] = "8.0;8.6" - else: - os.environ["PADDLE_CUDA_ARCH_LIST"] = "8.0;8.6" - cmdclass = {} def get_package_version(): @@ -183,7 +162,7 @@ def get_data_files(): # Specify the destination directory within the package destination_lib_path = os.path.join(PACKAGE_NAME, 'build/libflashattn.so') - data_files.append((os.path.join(PACKAGE_NAME, 'libflashattn.so'), [source_lib_path])) + data_files.append((paddle_lib_path, [source_lib_path])) return data_files @@ -215,160 +194,17 @@ def run(self): impl_tag, abi_tag, plat_tag = self.get_tag() original_wheel_name = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" - new_wheel_name = wheel_filename - print("self.asdfasdfsdfasdfasdfasdf", self.get_tag()) - shutil.move( - f"{self.dist_dir}/{original_wheel_name}.whl", - f"{self.dist_dir}/{new_wheel_name}" - ) - -class InstallHeaders(Command): - """Override how headers are copied.""" - - description = 'install C/C++ header files' - - user_options = [ - ('install-dir=', 'd', 'directory to install header files to'), - ('force', 'f', 'force installation (overwrite existing files)'), - ] - - boolean_options = ['force'] - - def initialize_options(self): - self.install_dir = None - self.force = 0 - self.outfiles = [] - - def finalize_options(self): - self.set_undefined_options( - 'install', ('install_headers', 'install_dir'), ('force', 'force') - ) - - def run(self): - hdrs = self.distribution.headers - if not hdrs: - return - self.mkpath(self.install_dir) - for header in hdrs: - install_dir = get_header_install_dir(header) - install_dir = os.path.join( - self.install_dir, os.path.dirname(install_dir) - ) - if not os.path.exists(install_dir): - self.mkpath(install_dir) - (out, _) = self.copy_file(header, install_dir) - self.outfiles.append(out) - # (out, _) = self.mkdir_and_copy_file(header) - # self.outfiles.append(out) - - def get_inputs(self): - return self.distribution.headers or [] - - def get_outputs(self): - return self.outfiles - - -class InstallCommand(InstallCommandBase): - def finalize_options(self): - ret = InstallCommandBase.finalize_options(self) - self.install_lib = self.install_platlib - - self.install_headers = os.path.join( - self.install_platlib, 'paddle', 'include' - ) - return ret - - -class DevelopCommand(DevelopCommandBase): - def run(self): - # copy proto and .so to python_source_dir - fluid_proto_binary_path = ( - paddle_binary_dir + '/python/paddle/base/proto/' - ) - fluid_proto_source_path = ( - paddle_source_dir + '/python/paddle/base/proto/' - ) - distributed_proto_binary_path = ( - paddle_binary_dir + '/python/paddle/distributed/fleet/proto/' - ) - distributed_proto_source_path = ( - paddle_source_dir + '/python/paddle/distributed/fleet/proto/' - ) - os.system(f"rm -rf {fluid_proto_source_path}") - shutil.copytree(fluid_proto_binary_path, fluid_proto_source_path) - os.system(f"rm -rf {distributed_proto_source_path}") - shutil.copytree( - distributed_proto_binary_path, distributed_proto_source_path - ) - shutil.copy( - paddle_binary_dir + '/python/paddle/base/libpaddle.so', - paddle_source_dir + '/python/paddle/base/', - ) - dynamic_library_binary_path = paddle_binary_dir + '/python/paddle/libs/' - dynamic_library_source_path = paddle_source_dir + '/python/paddle/libs/' - for lib_so in os.listdir(dynamic_library_binary_path): - shutil.copy( - dynamic_library_binary_path + lib_so, - dynamic_library_source_path, - ) - # write version.py and cuda_env_config_py to python_source_dir - write_version_py( - filename=f'{paddle_source_dir}/python/paddle/version/__init__.py' - ) - write_cuda_env_config_py( - filename=f'{paddle_source_dir}/python/paddle/cuda_env.py' - ) - write_parameter_server_version_py( - filename='{}/python/paddle/incubate/distributed/fleet/parameter_server/version.py'.format( - paddle_source_dir - ) - ) - DevelopCommandBase.run(self) - - -class EggInfo(egg_info): - """Copy license file into `.dist-info` folder.""" - - def run(self): - # don't duplicate license into `.dist-info` when building a distribution - if not self.distribution.have_run.get('install', True): - self.mkpath(self.egg_info) - #self.copy_file( - # env_dict.get("PADDLE_SOURCE_DIR") + "/LICENSE", self.egg_info - #) - - egg_info.run(self) - - -# class Installlib is rewritten to add header files to .egg/paddle -class InstallLib(install_lib): - def run(self): - self.build() - outfiles = self.install() - hrds = self.distribution.headers - if not hrds: - return - for header in hrds: - install_dir = get_header_install_dir(header) - install_dir = os.path.join( - self.install_dir, 'paddle/include', os.path.dirname(install_dir) - ) - if not os.path.exists(install_dir): - self.mkpath(install_dir) - self.copy_file(header, install_dir) - if outfiles is not None: - # always compile, in case we have any extension stubs to deal with - self.byte_compile(outfiles) - + new_wheel_name ='asdfsdf.whl' # wheel_filename + #shutil.move( + # f"{self.dist_dir}/{original_wheel_name}.whl", + # f"{self.dist_dir}/{new_wheel_name}" + #) setup( name=PACKAGE_NAME, version=get_package_version(), - packages=find_packages( - #exclude=("build") - #, "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",) - ), + packages=find_packages(), data_files=get_data_files(), package_data={PACKAGE_NAME: ['build/libflashattn.so']}, author_email="Paddle-better@baidu.com", From 66fc8a754ceee446a05964bbf99746536fd82c3f Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:00:56 +0800 Subject: [PATCH 09/37] update --- csrc/CMakeLists.txt | 5 +- csrc/README.md | 234 ------------ .../src/flash_bwd_launch_template.h | 17 + .../src/flash_fwd_launch_template.h | 22 ++ csrc/mp.py | 336 ------------------ csrc/tp.py | 77 ---- 6 files changed, 43 insertions(+), 648 deletions(-) delete mode 100644 csrc/README.md delete mode 100644 csrc/mp.py delete mode 100644 csrc/tp.py diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 3b84b53889c..9195ca56528 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -6,7 +6,10 @@ find_package(Git QUIET REQUIRED) execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} RESULT_VARIABLE GIT_SUBMOD_RESULT) - +#cmake -DWITH_ADVANCED +if (WITH_ADVANCED) + add_compile_definitions(PADDLE_WITH_ADVANCED)cu +endif() add_definitions("-DFLASH_ATTN_WITH_TORCH=0") set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) diff --git a/csrc/README.md b/csrc/README.md deleted file mode 100644 index 79d33453003..00000000000 --- a/csrc/README.md +++ /dev/null @@ -1,234 +0,0 @@ -# FlashAttention -This repository provides the official implementation of FlashAttention and -FlashAttention-2 from the -following papers. - -**FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness** -Tri Dao, Daniel Y. Fu, Stefano Ermon, Atri Rudra, Christopher Ré -Paper: https://arxiv.org/abs/2205.14135 -IEEE Spectrum [article](https://spectrum.ieee.org/mlperf-rankings-2022) about our submission to the MLPerf 2.0 benchmark using FlashAttention. -![FlashAttention](assets/flashattn_banner.jpg) - -**FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning** -Tri Dao - -Paper: https://tridao.me/publications/flash2/flash2.pdf - -![FlashAttention-2](assets/flashattention_logo.png) - - -## Usage - -We've been very happy to see FlashAttention being widely adopted in such a short -time after its release. This [page](https://github.com/Dao-AILab/flash-attention/blob/main/usage.md) -contains a partial list of places where FlashAttention is being used. - -FlashAttention and FlashAttention-2 are free to use and modify (see LICENSE). -Please cite and credit FlashAttention if you use it. - -## Installation and features - -Requirements: -- CUDA 11.4 and above. -- PyTorch 1.12 and above. - -We recommend the -[Pytorch](https://catalog.ngc.nvidia.com/orgs/nvidia/containers/pytorch) -container from Nvidia, which has all the required tools to install FlashAttention. - -To install: -1. Make sure that PyTorch is installed. -2. Make sure that `packaging` is installed (`pip install packaging`) -3. Make sure that `ninja` is installed and that it works correctly (e.g. `ninja ---version` then `echo $?` should return exit code 0). If not (sometimes `ninja ---version` then `echo $?` returns a nonzero exit code), uninstall then reinstall -`ninja` (`pip uninstall -y ninja && pip install ninja`). Without `ninja`, -compiling can take a very long time (2h) since it does not use multiple CPU -cores. With `ninja` compiling takes 3-5 minutes on a 64-core machine. -4. Then: -```sh -pip install flash-attn --no-build-isolation -``` -Alternatively you can compile from source: -```sh -python setup.py install -``` - -If your machine has less than 96GB of RAM and lots of CPU cores, `ninja` might -run too many parallel compilation jobs that could exhaust the amount of RAM. To -limit the number of parallel compilation jobs, you can set the environment -variable `MAX_JOBS`: -```sh -MAX_JOBS=4 pip install flash-attn --no-build-isolation -``` - -Interface: `src/flash_attention_interface.py` - -FlashAttention-2 currently supports: -1. Ampere, Ada, or Hopper GPUs (e.g., A100, RTX 3090, RTX 4090, H100). Support for Turing - GPUs (T4, RTX 2080) is coming soon, please use FlashAttention 1.x for Turing - GPUs for now. -2. Datatype fp16 and bf16 (bf16 requires Ampere, Ada, or Hopper GPUs). -3. All head dimensions up to 256. Head dim > 192 backward requires A100/A800 or H100/H800. - - -## How to use FlashAttention - -The main functions implement scaled dot product attention (softmax(Q @ K^T * -softmax_scale) @ V): -```python -from flash_attn import flash_attn_qkvpacked_func, flash_attn_func -``` - -```python -flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False): -"""dropout_p should be set to 0.0 during evaluation -If Q, K, V are already stacked into 1 tensor, this function will be faster than -calling flash_attn_func on Q, K, V since the backward pass avoids explicit concatenation -of the gradients of Q, K, V. -Arguments: - qkv: (batch_size, seqlen, 3, nheads, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). -Return: - out: (batch_size, seqlen, nheads, headdim). -""" -``` - -```python -flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False): -"""dropout_p should be set to 0.0 during evaluation -Supports multi-query and grouped-query attention (MQA/GQA) by passing in KV with fewer heads -than Q. Note that the number of heads in Q must be divisible by the number of heads in KV. -For example, if Q has 6 heads and K, V have 2 heads, head 0, 1, 2 of Q will attention to head -0 of K, V, and head 3, 4, 5 of Q will attention to head 1 of K, V. - -Arguments: - q: (batch_size, seqlen, nheads, headdim) - k: (batch_size, seqlen, nheads_k, headdim) - v: (batch_size, seqlen, nheads_k, headdim) - dropout_p: float. Dropout probability. - softmax_scale: float. The scaling of QK^T before applying softmax. - Default to 1 / sqrt(headdim). - causal: bool. Whether to apply causal attention mask (e.g., for auto-regressive modeling). -Return: - out: (batch_size, seqlen, nheads, headdim). -""" -``` - -To see how these functions are used in a multi-head attention layer (which -includes QKV projection, output projection), see the MHA [implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/modules/mha.py). - -## Upgrading from FlashAttention (1.x) to FlashAttention-2 - -These functions have been renamed: -- `flash_attn_unpadded_func` -> `flash_attn_varlen_func` -- `flash_attn_unpadded_qkvpacked_func` -> `flash_attn_varlen_qkvpacked_func` -- `flash_attn_unpadded_kvpacked_func` -> `flash_attn_varlen_kvpacked_func` - -If the inputs have the same sequence lengths in the same batch, it is simpler -and faster to use these functions: -```python -flash_attn_qkvpacked_func(qkv, dropout_p=0.0, softmax_scale=None, causal=False) -``` -```python -flash_attn_func(q, k, v, dropout_p=0.0, softmax_scale=None, causal=False) -``` - -## Performance - -We present expected speedup (combined forward + backward pass) and memory savings from using FlashAttention against PyTorch standard attention, depending on sequence length, on different GPUs (speedup depends on memory bandwidth - we see more speedup on slower GPU memory). - -We currently have benchmarks for these GPUs: -* [A100](#a100) -* [H100](#h100) - - - -### A100 - -We display FlashAttention speedup using these parameters: -* Head dimension 64 or 128, hidden dimension 2048 (i.e. either 32 or 16 heads). -* Sequence length 512, 1k, 2k, 4k, 8k, 16k. -* Batch size set to 16k / seqlen. - -#### Speedup - -![FlashAttention speedup on A100 80GB SXM5 with FP16/BF16](assets/flash2_a100_fwd_bwd_benchmark.png) - -#### Memory - -![FlashAttention memory](assets/flashattn_memory.jpg) - -We show memory savings in this graph (note that memory footprint is the same no matter if you use dropout or masking). -Memory savings are proportional to sequence length -- since standard attention has memory quadratic in sequence length, whereas FlashAttention has memory linear in sequence length. -We see 10X memory savings at sequence length 2K, and 20X at 4K. -As a result, FlashAttention can scale to much longer sequence lengths. - -### H100 - -![FlashAttention speedup on H100 SXM5 with FP16/BF16](assets/flash2_h100_fwd_bwd_benchmark.png) - -## Full model code and training script - -We have released the full GPT model -[implementation](https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/models/gpt.py). -We also provide optimized implementations of other layers (e.g., MLP, LayerNorm, -cross-entropy loss, rotary embedding). Overall this speeds up training by 3-5x -compared to the baseline implementation from Huggingface, reaching up to 225 -TFLOPs/sec per A100, equivalent to 72% model FLOPs utilization (we don't need -any activation checkpointing). - -We also include a training -[script](https://github.com/Dao-AILab/flash-attention/tree/main/training) to -train GPT2 on Openwebtext and GPT3 on The Pile. - -## Triton implementation of FlashAttention - -Phil Tillet (OpenAI) has an experimental implementation of FlashAttention in Triton: -https://github.com/openai/triton/blob/master/python/tutorials/06-fused-attention.py - -As Triton is a higher-level language than CUDA, it might be easier to understand -and experiment with. The notations in the Triton implementation are also closer -to what's used in our paper. - -We also have an experimental implementation in Triton that support attention -bias (e.g. ALiBi): -https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/flash_attn_triton.py - - -## Tests -We test that FlashAttention produces the same output and gradient as a reference -implementation, up to some numerical tolerance. In particular, we check that the -maximum numerical error of FlashAttention is at most twice the numerical error -of a baseline implementation in Pytorch (for different head dimensions, input -dtype, sequence length, causal / non-causal). - -To run the tests: -```sh -pytest -q -s tests/test_flash_attn.py -``` -## When you encounter issues - -This new release of FlashAttention-2 has been tested on several GPT-style -models, mostly on A100 GPUs. - -If you encounter bugs, please open a GitHub Issue! - -## Citation -If you use this codebase, or otherwise found our work valuable, please cite: -``` -@inproceedings{dao2022flashattention, - title={Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness}, - author={Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher}, - booktitle={Advances in Neural Information Processing Systems}, - year={2022} -} -@article{dao2023flashattention2, - title={Flash{A}ttention-2: Faster Attention with Better Parallelism and Work Partitioning, - author={Dao, Tri}, - year={2023} -} -``` diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 2c62e6c5797..f3eb8850f0d 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -64,6 +64,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool is_attn_mask = params.attn_mask_ptr != nullptr; const bool is_deterministic = params.num_splits == 1; // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv); +#ifdef PADDLE_WITH_ADVANCED BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { @@ -82,6 +83,22 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, }); }); }); +#else + BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); +#endif auto kernel_dq = &flash_bwd_convert_dq_kernel; if (Kernel_traits::kSmemdQSize >= 48 * 1024) { diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 6c638261766..5090605cb56 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -36,6 +36,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const bool return_softmax = params.p_ptr != nullptr; const bool is_attn_mask = params.attn_mask_ptr != nullptr; const bool is_equal_qk = (params.seqlen_q == params.seqlen_k) && (Is_causal) && (!is_attn_mask); +#ifdef PADDLE_WITH_ADVANCED BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { @@ -59,6 +60,27 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }); }); }); +#else + BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + // Will only return softmax if dropout, to reduce compilation time. + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); +#endif } template diff --git a/csrc/mp.py b/csrc/mp.py deleted file mode 100644 index ab8435e8e71..00000000000 --- a/csrc/mp.py +++ /dev/null @@ -1,336 +0,0 @@ -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -import re -import ast -from pathlib import Path -from packaging.version import parse, Version -import platform - -from setuptools import setup, find_packages -import subprocess - -import urllib.request -import urllib.error -from wheel.bdist_wheel import bdist_wheel as _bdist_wheel - -import paddle -from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension -from paddle.utils.cpp_extension.extension_utils import find_cuda_home - -with open("README.md", "r", encoding="utf-8") as fh: - long_description = fh.read() - - -# ninja build does not work unless include_dirs are abs path -this_dir = os.path.dirname(os.path.abspath(__file__)) -CUDA_HOME = find_cuda_home() -PACKAGE_NAME = "flash_attn" - -BASE_WHEEL_URL = "https://github.com/Dao-AILab/flash-attention/releases/download/{tag_name}/{wheel_name}" - -# FORCE_BUILD: Force a fresh build locally, instead of attempting to find prebuilt wheels -# SKIP_CUDA_BUILD: Intended to allow CI to use a simple `python setup.py sdist` run to copy over raw files, without any cuda compilation -FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" -SKIP_CUDA_BUILD = os.getenv("FLASH_ATTENTION_SKIP_CUDA_BUILD", "FALSE") == "TRUE" -# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI -FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" -# For CI, we want the option to not add "--threads 4" to nvcc, since the runner can OOM -FORCE_SINGLE_THREAD = os.getenv("FLASH_ATTENTION_FORCE_SINGLE_THREAD", "FALSE") == "TRUE" - - -def get_platform(): - """ - Returns the platform name as used in wheel filenames. - """ - if sys.platform.startswith('linux'): - return 'linux_x86_64' - elif sys.platform == 'darwin': - mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2]) - return f'macosx_{mac_version}_x86_64' - elif sys.platform == 'win32': - return 'win_amd64' - else: - raise ValueError('Unsupported platform: {}'.format(sys.platform)) - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def check_cuda_paddle_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - paddle_binary_version = parse(paddle.version.cuda()) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != paddle_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pypaddle binaries. " - "Pypaddle binaries were compiled with Cuda {}.\n".format(paddle.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pypaddle/pypaddle, " - "only images whose names contain 'devel' will provide nvcc." - ) - - -def append_nvcc_threads(nvcc_extra_args): - if not FORCE_SINGLE_THREAD: - return nvcc_extra_args + ["--threads", "4"] - return nvcc_extra_args - -def _is_cuda_available(): - """ - Check whether CUDA is available. - """ - try: - assert len(paddle.static.cuda_places()) > 0 - return True - except Exception as e: - logging.warning( - "You are using GPU version PaddlePaddle, but there is no GPU " - "detected on your machine. Maybe CUDA devices is not set properly." - f"\n Original Error is {e}" - ) - return False - -if paddle.is_compiled_with_cuda() and _is_cuda_available(): - # https://github.com/NVIDIA/apex/issues/486 - # Extension builds after https://github.com/pypaddle/pypaddle/pull/23408 attempt to query paddle.cuda.get_device_capability(), - # which will fail if you are compiling in an environment without visible GPUs (e.g. during an nvidia-docker build command). - print( - "\nWarning: Torch did not find available GPUs on this system.\n", - "If your intention is to cross-compile, this is not an error.\n" - "By default, FlashAttention will cross-compile for Ampere (compute capability 8.0, 8.6, " - "8.9), and, if the CUDA version is >= 11.8, Hopper (compute capability 9.0).\n" - "If you wish to cross-compile for a single specific architecture,\n" - 'export TORCH_CUDA_ARCH_LIST="compute capability" before running setup.py.\n', - ) - if os.environ.get("TORCH_CUDA_ARCH_LIST", None) is None and CUDA_HOME is not None: - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version >= Version("11.8"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6;9.0" - elif bare_metal_version >= Version("11.4"): - os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6" - else: - os.environ["TORCH_CUDA_ARCH_LIST"] = "8.0;8.6" - -cmdclass = {} -ext_modules = [] - -# We want this even if SKIP_CUDA_BUILD because when we run python setup.py sdist we want the .hpp -# files included in the source distribution, in case the user compiles from source. -subprocess.run(["git", "submodule", "update", "--init", "csrc/cutlass"]) - -if not SKIP_CUDA_BUILD: - print("\n\npaddle.__version__ = {}\n\n".format(paddle.__version__)) - TORCH_MAJOR = int(paddle.__version__.split(".")[0]) - TORCH_MINOR = int(paddle.__version__.split(".")[1]) - - # Check, if ATen/CUDAGeneratorImpl.h is found, otherwise use ATen/cuda/CUDAGeneratorImpl.h - # See https://github.com/pypaddle/pypaddle/pull/70650 - generator_flag = [] - paddle_dir = paddle.__path__[0] - if os.path.exists(os.path.join(paddle_dir, "include", "ATen", "CUDAGeneratorImpl.h")): - generator_flag = ["-DOLD_GENERATOR_PATH"] - - raise_if_cuda_home_none("flash_attn") - # Check, if CUDA11 is installed for compute capability 8.0 - cc_flag = [] - _, bare_metal_version = get_cuda_bare_metal_version(CUDA_HOME) - if bare_metal_version < Version("11.4"): - raise RuntimeError("FlashAttention is only supported on CUDA 11.4 and above") - # cc_flag.append("-gencode") - # cc_flag.append("arch=compute_75,code=sm_75") - cc_flag.append("-gencode") - cc_flag.append("arch=compute_80,ggcode=sm_80") - if bare_metal_version >= Version("11.8"): - cc_flag.append("-gencode") - cc_flag.append("arch=compute_90,code=sm_90") - - # HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as - # paddle._C._GLIBCXX_USE_CXX11_ABI - # https://github.com/pypaddle/pypaddle/blob/8472c24e3b5b60150096486616d98b7bea01500b/paddle/utils/cpp_extension.py#L920 - if FORCE_CXX11_ABI: - paddle._C._GLIBCXX_USE_CXX11_ABI = True - ext_modules.append( - CUDAExtension( - sources=[ - "csrc/flash_attn/flash_api.cpp", - "csrc/flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", - "csrc/flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu", - "csrc/flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu", - ], - extra_compile_args={ - "cxx": ["-O3", "-std=c++17"] + generator_flag, - "nvcc": append_nvcc_threads( - [ - "-O3", - "-std=c++17", - "-U__CUDA_NO_HALF_OPERATORS__", - "-U__CUDA_NO_HALF_CONVERSIONS__", - "-U__CUDA_NO_HALF2_OPERATORS__", - "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", - "--expt-relaxed-constexpr", - "--expt-extended-lambda", - "--use_fast_math", - "--ptxas-options=-v", - # "--ptxas-options=-O2", - "-lineinfo" - ] - + generator_flag - + cc_flag - ), - }, - include_dirs=[ - Path(this_dir) / 'csrc' / 'flash_attn', - Path(this_dir) / 'csrc' / 'flash_attn' / 'src', - Path(this_dir) / 'csrc' / 'cutlass' / 'include', - ], - ) - ) - - -def get_package_version(): - with open(Path(this_dir) / "../flash_attn" / "__init__.py", "r") as f: - version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) - public_version = ast.literal_eval(version_match.group(1)) - local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION") - if local_version: - return f"{public_version}+{local_version}" - else: - return str(public_version) - - -class CachedWheelsCommand(_bdist_wheel): - """ - The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot - find an existing wheel (which is currently the case for all flash attention installs). We use - the environment parameters to detect whether there is already a pre-built version of a compatible - wheel available and short-circuits the standard full build pipeline. - """ - def run(self): - if FORCE_BUILD: - return super().run() - - # Determine the version numbers that will be used to determine the correct wheel - # We're using the CUDA version used to build paddle, not the one currently installed - # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) - paddle_cuda_version = parse(paddle.version.cuda) - paddle_version_raw = parse(paddle.__version__) - python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" - platform_name = get_platform() - flash_version = get_package_version() - # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" - cuda_version = f"{paddle_cuda_version.major}{paddle_cuda_version.minor}" - paddle_version = f"{paddle_version_raw.major}.{paddle_version_raw.minor}" - cxx11_abi = str(paddle._C._GLIBCXX_USE_CXX11_ABI).upper() - - # Determine wheel URL based on CUDA version, paddle version, python version and OS - wheel_filename = f'{PACKAGE_NAME}-{flash_version}+cu{cuda_version}paddle{paddle_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl' - wheel_url = BASE_WHEEL_URL.format( - tag_name=f"v{flash_version}", - wheel_name=wheel_filename - ) - print("Guessing wheel URL: ", wheel_url) - - try: - urllib.request.urlretrieve(wheel_url, wheel_filename) - - # Make the archive - # Lifted from the root wheel processing command - # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 - if not os.path.exists(self.dist_dir): - os.makedirs(self.dist_dir) - - impl_tag, abi_tag, plat_tag = self.get_tag() - archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" - - wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") - print("Raw wheel path", wheel_path) - os.rename(wheel_filename, wheel_path) - except urllib.error.HTTPError: - print("Precompiled wheel not found. Building from source...") - # If the wheel could not be downloaded, build from source - super().run() - - -setup( - name=PACKAGE_NAME, - version=get_package_version(), - packages=find_packages( - exclude=("build", "csrc", "include", "tests", "dist", "docs", "benchmarks", "flash_attn.egg-info",) - ), - author="Tri Dao", - author_email="trid@cs.stanford.edu", - description="Flash Attention: Fast and Memory-Efficient Exact Attention", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/Dao-AILab/flash-attention", - classifiers=[ - "Programming Language :: Python :: 3", - "License :: OSI Approved :: BSD License", - "Operating System :: Unix", - ], - #ext_modules=ext_modules, - cmdclass={ - 'bdist_wheel': CachedWheelsCommand, - "build_ext": BuildExtension - } if ext_modules else { - 'bdist_wheel': CachedWheelsCommand, - }, - python_requires=">=3.7", - install_requires=[ - "paddle", - "einops", - "packaging", - "ninja", - ], -) diff --git a/csrc/tp.py b/csrc/tp.py deleted file mode 100644 index 56aebee074d..00000000000 --- a/csrc/tp.py +++ /dev/null @@ -1,77 +0,0 @@ -import paddle -from setuptools import setup, find_packages -import sys -import os -import paddle -paddle_path = paddle.sysconfig.get_lib -print(paddle_path) -python_version = sys.version -print("Installing your_package...") - -# Get the CUDA version from PaddlePaddle -cuda_version = paddle.version.cuda() -fa_version = f"1.0.0.post{cuda_version}" -package_name = 'flash_attention_paddle_gpu' - -def get_data_files(): - data_files = [] - - # Assuming 'libflashattn.so' is located in the same directory as setup.py - source_lib_path = 'libflashattn.so' - - # Specify the destination directory within the package - destination_lib_path = os.path.join(package_name, 'libflashattn.so') - - data_files.append((os.path.join(package_name, 'libflashattn.so'), [source_lib_path])) - print(destination_lib_path, "asdf ****************") - print(data_files) - return data_files - -setup( - name=package_name, - version=fa_version, - data_files=get_data_files(), - description='Flash attention in paddlepaddle', - packages=find_packages(), - package_data={package_name: ['build/libflashattn.so']}, -) -# -#import paddle -#import os -#from setuptools import setup -#import sys -# -#python_version = sys.version -#print("Installing your_package...") -# -## Get the CUDA version from PaddlePaddle -#cuda_version = paddle.version.cuda() -#fa_version = f"1.0.0.post{cuda_version}" -#package_name = 'flash_attention_paddle_gpu' # Adjusted package name -# -#def get_data_files(): -# data_files = [] -# -# # Assuming 'libflashattn.so' is located in the same directory as setup.py -# source_lib_path = os.path.abspath('libflashattn.so') -# -# # Specify the destination directory within the package -# destination_lib_path = os.path.join(package_name, 'libflashattn.so') -# -# data_files.append((os.path.join(package_name, 'libflashattn.so'), [source_lib_path])) -# print(destination_lib_path, "asdf ****************") -# print(data_files) -# return data_files -# -## Create an empty __init__.py file in the package directory -#init_file_path = os.path.join(package_name, '__init__.py') -#with open(init_file_path, 'w') as f: -# pass -# -#setup( -# name=package_name, -# version=fa_version, -# description='Flash attention in paddlepaddle', -# packages=[package_name], -# package_data={package_name: ['libflashattn.so']}, -#) From 41ebd074fb8942a2cefe7f1f318945ab8176475a Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:05:05 +0800 Subject: [PATCH 10/37] all --- csrc/CMakeLists.txt | 80 ++++++++++++++++++++++----------------------- 1 file changed, 40 insertions(+), 40 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 9195ca56528..a051f0552f3 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -16,38 +16,38 @@ set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) set(FA2_SOURCES_CU flash_attn/src/cuda_utils.cu - #flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu ) add_library(flashattn SHARED @@ -60,13 +60,13 @@ target_include_directories(flashattn PRIVATE set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu - #flash_attn_with_bias_and_mask/src/cuda_utils.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu + flash_attn_with_bias_and_mask/src/cuda_utils.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu flash_attn_with_bias_and_mask/src/utils.cu) add_library(flashattn_with_bias_mask STATIC @@ -141,7 +141,7 @@ INSTALL(TARGETS flashattn INSTALL(FILES capi/flash_attn.h DESTINATION "include") add_custom_target(run_my_executable - COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py sdist bdist_wheel + COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel WORKING_DIRECTORY ${CMAKE_BINARY_DIR} DEPENDS flashattn COMMENT "Running my_executable" From ad614e0dfcaa8862eedb59506d65b2ed8f6a5a39 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:06:10 +0800 Subject: [PATCH 11/37] update --- csrc/yest | 3 --- 1 file changed, 3 deletions(-) delete mode 100644 csrc/yest diff --git a/csrc/yest b/csrc/yest deleted file mode 100644 index b3d4d3cd0b5..00000000000 --- a/csrc/yest +++ /dev/null @@ -1,3 +0,0 @@ -include build/libflashattn.so -include src/libflashattn.so -include ./libflashattn.so From c8d003ab709aec89ada749677d2c3e4c014d860b Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:07:16 +0800 Subject: [PATCH 12/37] update --- yes.py | 12 ------------ 1 file changed, 12 deletions(-) delete mode 100644 yes.py diff --git a/yes.py b/yes.py deleted file mode 100644 index 29917c43ecc..00000000000 --- a/yes.py +++ /dev/null @@ -1,12 +0,0 @@ -from setuptools import setup - -package_name = '' #flash-attention-paddle-gpu' -setup( - name=package_name, - version='1.0.0', - description='Flash attention in PaddlePaddle', - packages=[package_name], - include_package_data=True, - package_data={package_name: ['csrc/build/libflashattn.so']}, -) - From 559a47952de47cfd5aeecb66cdc306208757dec8 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:08:29 +0800 Subject: [PATCH 13/37] update --- csrc/CMakeLists.txt | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index a051f0552f3..bac78e5b438 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -134,7 +134,6 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$) -# INSTALL(TARGETS flashattn LIBRARY DESTINATION "lib") @@ -147,9 +146,8 @@ add_custom_target(run_my_executable COMMENT "Running my_executable" ) -# 创建一个伪目标作为默认构建目标 add_custom_target(default_target DEPENDS run_my_executable) -# 设置 'default_target' 为默认构建目标 +# set 'default_target' set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target) From bd670ae650df79b3b60b98d4de6be8fd69f00d41 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:10:16 +0800 Subject: [PATCH 14/37] 80 90 --- csrc/CMakeLists.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index bac78e5b438..27ac1fdb692 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -6,10 +6,12 @@ find_package(Git QUIET REQUIRED) execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} RESULT_VARIABLE GIT_SUBMOD_RESULT) + #cmake -DWITH_ADVANCED if (WITH_ADVANCED) add_compile_definitions(PADDLE_WITH_ADVANCED)cu endif() + add_definitions("-DFLASH_ATTN_WITH_TORCH=0") set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) From e4b500675954bf1a9f48189e512e1c00b73c902a Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:22:11 +0800 Subject: [PATCH 15/37] error --- csrc/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 27ac1fdb692..c98cdbb450b 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -9,7 +9,7 @@ execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive #cmake -DWITH_ADVANCED if (WITH_ADVANCED) - add_compile_definitions(PADDLE_WITH_ADVANCED)cu + add_compile_definitions(PADDLE_WITH_ADVANCED) endif() add_definitions("-DFLASH_ATTN_WITH_TORCH=0") From 4f7f1f0dfa0793fc24fe495bf55a0efa75986442 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:28:27 +0800 Subject: [PATCH 16/37] update build ok --- csrc/CMakeLists.txt | 79 +++++++++++++++++++++++---------------------- csrc/setup.py | 2 +- 2 files changed, 41 insertions(+), 40 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index c98cdbb450b..fbf48d4922e 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -18,38 +18,38 @@ set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) set(FA2_SOURCES_CU flash_attn/src/cuda_utils.cu - flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu ) add_library(flashattn SHARED @@ -62,13 +62,13 @@ target_include_directories(flashattn PRIVATE set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu - flash_attn_with_bias_and_mask/src/cuda_utils.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu + #flash_attn_with_bias_and_mask/src/cuda_utils.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu flash_attn_with_bias_and_mask/src/utils.cu) add_library(flashattn_with_bias_mask STATIC @@ -100,6 +100,7 @@ endif() STRING(REPLACE "-" ";" FA_NVCC_ARCH_BIN ${NVCC_ARCH_BIN}) set(FA_GENCODE_OPTION "SHELL:") + foreach(arch ${FA_NVCC_ARCH_BIN}) if(${arch} GREATER_EQUAL 80) set(FA_GENCODE_OPTION "${FA_GENCODE_OPTION} -gencode arch=compute_${arch},code=sm_${arch}") diff --git a/csrc/setup.py b/csrc/setup.py index dcfc4e6915b..d26d6136f3f 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -57,7 +57,7 @@ cuda_version= paddle.version.cuda_version -with open("../README.md", "r", encoding="utf-8") as fh: +with open("../../README.md", "r", encoding="utf-8") as fh: long_description = fh.read() From 8c12f7266657a1627f4331053de776b843e610ba Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 19:39:47 +0800 Subject: [PATCH 17/37] update --- csrc/setup.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index d26d6136f3f..b40d72c9a51 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -194,11 +194,12 @@ def run(self): impl_tag, abi_tag, plat_tag = self.get_tag() original_wheel_name = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" - new_wheel_name ='asdfsdf.whl' # wheel_filename - #shutil.move( - # f"{self.dist_dir}/{original_wheel_name}.whl", - # f"{self.dist_dir}/{new_wheel_name}" - #) + #new_wheel_name = wheel_filename + new_wheel_name = f"{self.wheel_dist_name}-{python_version}-{abi_tag}-{plat_tag}" + shutil.move( + f"{self.dist_dir}/{original_wheel_name}.whl", + f"{self.dist_dir}/{new_wheel_name}.whl" + ) setup( From 7b257e8ad308b30a5a98bd861b7d14e1f3a30a74 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:02:19 +0800 Subject: [PATCH 18/37] update --- csrc/setup.py | 29 ++++++++++++++++++++--------- 1 file changed, 20 insertions(+), 9 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index b40d72c9a51..39a5e338ad2 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -18,6 +18,7 @@ import urllib.request import urllib.error from wheel.bdist_wheel import bdist_wheel as _bdist_wheel +from setuptools.command.install import install import paddle from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension @@ -159,14 +160,11 @@ def get_data_files(): # Assuming 'libflashattn.so' is located in the same directory as setup.py source_lib_path = 'libflashattn.so' - # Specify the destination directory within the package - destination_lib_path = os.path.join(PACKAGE_NAME, 'build/libflashattn.so') - - data_files.append((paddle_lib_path, [source_lib_path])) + data_files.append((".", [source_lib_path])) return data_files -class CachedWheelsCommand(_bdist_wheel): +class CustomWheelsCommand(_bdist_wheel): """ The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot find an existing wheel (which is currently the case for all flash attention installs). We use @@ -174,9 +172,6 @@ class CachedWheelsCommand(_bdist_wheel): wheel available and short-circuits the standard full build pipeline. """ def run(self): - print("88888888888888888888888888888") - # if FORCE_BUILD: - # return super().run() self.run_command('build_ext') super().run() # Determine the version numbers that will be used to determine the correct wheel @@ -202,6 +197,21 @@ def run(self): ) +class CustomInstallCommand(install): + def run(self): + install.run(self) + install_path = self.install_lib + # src + source_lib_path = os.path.abspath('libflashattn.so') + + # 目标链接路径 + destination_lib_path = os.path.join(paddle_lib_path, 'libflashattn.so') + + # 创建软链接 + shutil.move(f"{source_lib_path}", f"{destination_lib_path}") + #os.symlink(source_lib_path, destination_lib_path) + + setup( name=PACKAGE_NAME, version=get_package_version(), @@ -219,7 +229,8 @@ def run(self): "Operating System :: Unix", ], cmdclass={ - "bdist_wheel": CachedWheelsCommand,}, + 'bdist_wheel': CustomWheelsCommand, + 'install': CustomInstallCommand}, python_requires=">=3.7", install_requires=[ "paddle", From 4fd33eaf567f5a837f63bd5b1d2e7383c006d76e Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:16:49 +0800 Subject: [PATCH 19/37] updaet --- csrc/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index fbf48d4922e..fba57f0e21a 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -142,14 +142,14 @@ INSTALL(TARGETS flashattn INSTALL(FILES capi/flash_attn.h DESTINATION "include") -add_custom_target(run_my_executable +add_custom_target(build_whl COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel WORKING_DIRECTORY ${CMAKE_BINARY_DIR} DEPENDS flashattn COMMENT "Running my_executable" ) -add_custom_target(default_target DEPENDS run_my_executable) +add_custom_target(default_target DEPENDS build_whl) # set 'default_target' set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target) From f03a1dffd0f3661c80c7d4098f5769a04b3febae Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:17:35 +0800 Subject: [PATCH 20/37] updaet --- csrc/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index fba57f0e21a..48679712781 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -146,7 +146,7 @@ add_custom_target(build_whl COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel WORKING_DIRECTORY ${CMAKE_BINARY_DIR} DEPENDS flashattn - COMMENT "Running my_executable" + COMMENT "Running build wheel" ) add_custom_target(default_target DEPENDS build_whl) From d8101084517132fa4f43cd657edd5fc3e28d792f Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:21:27 +0800 Subject: [PATCH 21/37] upate --- csrc/setup.py | 41 +---------------------------------------- 1 file changed, 1 insertion(+), 40 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index 39a5e338ad2..39e224df61c 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -67,15 +67,6 @@ CUDA_HOME = find_cuda_home() PACKAGE_NAME = "paddle_flash_attn" -BASE_WHEEL_URL = "https://github.com/PaddlePaddle/flash-attention/releases/download/{tag_name}/{wheel_name}" - -FORCE_BUILD = os.getenv("FLASH_ATTENTION_FORCE_BUILD", "FALSE") == "TRUE" -# For CI, we want the option to build with C++11 ABI since the nvcr images use C++11 ABI -FORCE_CXX11_ABI = os.getenv("FLASH_ATTENTION_FORCE_CXX11_ABI", "FALSE") == "TRUE" -# For CI, we want the option to not add "--threads 4" to nvcc, since the runner can OOM -FORCE_SINGLE_THREAD = os.getenv("FLASH_ATTENTION_FORCE_SINGLE_THREAD", "FALSE") == "TRUE" - - def get_platform(): """ Returns the platform name as used in wheel filenames. @@ -100,33 +91,6 @@ def get_cuda_bare_metal_version(cuda_dir): return raw_output, bare_metal_version -def check_cuda_paddle_binary_vs_bare_metal(cuda_dir): - raw_output, bare_metal_version = get_cuda_bare_metal_version(cuda_dir) - paddle_binary_version = parse(paddle.version.cuda()) - - print("\nCompiling cuda extensions with") - print(raw_output + "from " + cuda_dir + "/bin\n") - - if (bare_metal_version != paddle_binary_version): - raise RuntimeError( - "Cuda extensions are being compiled with a version of Cuda that does " - "not match the version used to compile Pypaddle binaries. " - "Pypaddle binaries were compiled with Cuda {}.\n".format(paddle.version.cuda) - + "In some cases, a minor-version mismatch will not cause later errors: " - "https://github.com/NVIDIA/apex/pull/323#discussion_r287021798. " - "You can try commenting out this check (at your own risk)." - ) - - -def raise_if_cuda_home_none(global_option: str) -> None: - if CUDA_HOME is not None: - return - raise RuntimeError( - f"{global_option} was requested, but nvcc was not found. Are you sure your environment has nvcc available? " - "If you're installing within a container from https://hub.docker.com/r/pypaddle/pypaddle, " - "only images whose names contain 'devel' will provide nvcc." - ) - def _is_cuda_available(): """ Check whether CUDA is available. @@ -141,7 +105,7 @@ def _is_cuda_available(): f"\n Original Error is {e}" ) return False - +check = _is_cuda_available() cmdclass = {} def get_package_version(): @@ -156,10 +120,7 @@ def get_package_version(): def get_data_files(): data_files = [] - - # Assuming 'libflashattn.so' is located in the same directory as setup.py source_lib_path = 'libflashattn.so' - data_files.append((".", [source_lib_path])) return data_files From 48eb6479fd7163bd20164f0e36ad9a3b5f7ec510 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:38:28 +0800 Subject: [PATCH 22/37] update --- csrc/setup.py | 90 ++++++++++++++++++++++++++++++++------------------- 1 file changed, 56 insertions(+), 34 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index 39e224df61c..1e7a63293fe 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -1,29 +1,36 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + # Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import sys -import warnings -import os -import re import ast -from pathlib import Path -from packaging.version import parse, Version +import logging +import os import platform +import re import shutil -from setuptools import Command, Extension, setup -from setuptools.command.develop import develop as DevelopCommandBase -from setuptools.command.egg_info import egg_info -from setuptools.command.install import install as InstallCommandBase -from setuptools.command.install_lib import install_lib -from setuptools.dist import Distribution -from setuptools import setup, find_packages -import urllib.request -import urllib.error +import subprocess +import sys +import warnings +from pathlib import Path + +from packaging.version import parse +from setuptools import find_packages, setup +from setuptools.command.install import install as _install from wheel.bdist_wheel import bdist_wheel as _bdist_wheel -from setuptools.command.install import install import paddle -from paddle.utils.cpp_extension import BuildExtension, CppExtension, CUDAExtension from paddle.utils.cpp_extension.extension_utils import find_cuda_home -import subprocess version_detail = sys.version_info python_version = platform.python_version() @@ -55,7 +62,7 @@ # preparing parameters for setup() paddle_version = paddle.version.full_version -cuda_version= paddle.version.cuda_version +cuda_version = paddle.version.cuda_version with open("../../README.md", "r", encoding="utf-8") as fh: @@ -67,6 +74,7 @@ CUDA_HOME = find_cuda_home() PACKAGE_NAME = "paddle_flash_attn" + def get_platform(): """ Returns the platform name as used in wheel filenames. @@ -79,11 +87,13 @@ def get_platform(): elif sys.platform == 'win32': return 'win_amd64' else: - raise ValueError('Unsupported platform: {}'.format(sys.platform)) + raise ValueError(f'Unsupported platform: {sys.platform}') def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output([cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True) + raw_output = subprocess.check_output( + [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True + ) output = raw_output.split() release_idx = output.index("release") + 1 bare_metal_version = parse(output[release_idx].split(",")[0]) @@ -105,12 +115,17 @@ def _is_cuda_available(): f"\n Original Error is {e}" ) return False + + check = _is_cuda_available() cmdclass = {} + def get_package_version(): with open(Path(this_dir) / "../flash_attn" / "__init__.py", "r") as f: - version_match = re.search(r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE) + version_match = re.search( + r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE + ) public_version = ast.literal_eval(version_match.group(1)) local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION") if local_version: @@ -118,6 +133,7 @@ def get_package_version(): else: return str(public_version) + def get_data_files(): data_files = [] source_lib_path = 'libflashattn.so' @@ -132,35 +148,40 @@ class CustomWheelsCommand(_bdist_wheel): the environment parameters to detect whether there is already a pre-built version of a compatible wheel available and short-circuits the standard full build pipeline. """ + def run(self): self.run_command('build_ext') super().run() # Determine the version numbers that will be used to determine the correct wheel # We're using the CUDA version used to build paddle, not the one currently installed # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) - paddle_cuda_version = "234" #parse(paddle.version.cuda) + paddle_cuda_version = "234" # parse(paddle.version.cuda) paddle_version_raw = parse(paddle.__version__) python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_name = get_platform() flash_version = get_package_version() - cxx11_abi ="" # str(paddle._C.-D_GLIBCXX_USE_CXX11_ABI).upper() + cxx11_abi = "" # str(paddle._C.-D_GLIBCXX_USE_CXX11_ABI).upper() # Determine wheel URL based on CUDA version, paddle version, python version and OS wheel_filename = f'{PACKAGE_NAME}-{flash_version}-cu{cuda_version}-paddle{paddle_version}-{python_version}-{python_version}-{platform_name}.whl' impl_tag, abi_tag, plat_tag = self.get_tag() - original_wheel_name = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + original_wheel_name = ( + f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + ) - #new_wheel_name = wheel_filename - new_wheel_name = f"{self.wheel_dist_name}-{python_version}-{abi_tag}-{plat_tag}" + # new_wheel_name = wheel_filename + new_wheel_name = ( + f"{self.wheel_dist_name}-{python_version}-{abi_tag}-{plat_tag}" + ) shutil.move( f"{self.dist_dir}/{original_wheel_name}.whl", - f"{self.dist_dir}/{new_wheel_name}.whl" - ) + f"{self.dist_dir}/{new_wheel_name}.whl", + ) -class CustomInstallCommand(install): +class CustomInstallCommand(_install): def run(self): - install.run(self) + super().run(self) install_path = self.install_lib # src source_lib_path = os.path.abspath('libflashattn.so') @@ -170,7 +191,7 @@ def run(self): # 创建软链接 shutil.move(f"{source_lib_path}", f"{destination_lib_path}") - #os.symlink(source_lib_path, destination_lib_path) + # os.symlink(source_lib_path, destination_lib_path) setup( @@ -190,8 +211,9 @@ def run(self): "Operating System :: Unix", ], cmdclass={ - 'bdist_wheel': CustomWheelsCommand, - 'install': CustomInstallCommand}, + 'bdist_wheel': CustomWheelsCommand, + 'install': CustomInstallCommand, + }, python_requires=">=3.7", install_requires=[ "paddle", From 7bb6f314f0995d2a1bc67aa95e2a0edd573dd906 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:40:18 +0800 Subject: [PATCH 23/37] update --- csrc/setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/setup.py b/csrc/setup.py index 1e7a63293fe..004c891332c 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -181,7 +181,7 @@ def run(self): class CustomInstallCommand(_install): def run(self): - super().run(self) + _install.run(self) install_path = self.install_lib # src source_lib_path = os.path.abspath('libflashattn.so') From 58563ba8c8172dad5439041fc396d61dc120d449 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:41:52 +0800 Subject: [PATCH 24/37] udpate --- csrc/setup.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index 004c891332c..11cbb4dabf2 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -186,11 +186,9 @@ def run(self): # src source_lib_path = os.path.abspath('libflashattn.so') - # 目标链接路径 destination_lib_path = os.path.join(paddle_lib_path, 'libflashattn.so') - # 创建软链接 - shutil.move(f"{source_lib_path}", f"{destination_lib_path}") + # shutil.move(f"{source_lib_path}", f"{destination_lib_path}") # os.symlink(source_lib_path, destination_lib_path) From e856a057707bb1d101e1e49bf2558646fb11cfd4 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Wed, 6 Dec 2023 20:44:23 +0800 Subject: [PATCH 25/37] update --- csrc/CMakeLists.txt | 78 ++++++++++++++++++++++----------------------- 1 file changed, 39 insertions(+), 39 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 48679712781..b571791a031 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -18,38 +18,38 @@ set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) set(FA2_SOURCES_CU flash_attn/src/cuda_utils.cu - #flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu ) add_library(flashattn SHARED @@ -62,13 +62,13 @@ target_include_directories(flashattn PRIVATE set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu - #flash_attn_with_bias_and_mask/src/cuda_utils.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu + flash_attn_with_bias_and_mask/src/cuda_utils.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu flash_attn_with_bias_and_mask/src/utils.cu) add_library(flashattn_with_bias_mask STATIC From 256a3c6b8a9b13923cb1c7f4e9cae6dd2a726361 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Thu, 7 Dec 2023 16:36:06 +0800 Subject: [PATCH 26/37] update --- csrc/CMakeLists.txt | 107 +++++++++++++++++++++++--------------------- csrc/setup.py | 7 ++- 2 files changed, 60 insertions(+), 54 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index b571791a031..a6b9e6e0059 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -7,7 +7,7 @@ execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} RESULT_VARIABLE GIT_SUBMOD_RESULT) -#cmake -DWITH_ADVANCED +#cmake -DWITH_ADVANCED=ON if (WITH_ADVANCED) add_compile_definitions(PADDLE_WITH_ADVANCED) endif() @@ -18,38 +18,38 @@ set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) set(FA2_SOURCES_CU flash_attn/src/cuda_utils.cu - flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu - flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu - flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu - flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu - flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu + #flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu + #flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu + #flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu + #flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu ) add_library(flashattn SHARED @@ -63,12 +63,12 @@ target_include_directories(flashattn PRIVATE set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu flash_attn_with_bias_and_mask/src/cuda_utils.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu - flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu - flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu + #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu + #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu flash_attn_with_bias_and_mask/src/utils.cu) add_library(flashattn_with_bias_mask STATIC @@ -142,15 +142,22 @@ INSTALL(TARGETS flashattn INSTALL(FILES capi/flash_attn.h DESTINATION "include") -add_custom_target(build_whl - COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel - WORKING_DIRECTORY ${CMAKE_BINARY_DIR} - DEPENDS flashattn - COMMENT "Running build wheel" -) - -add_custom_target(default_target DEPENDS build_whl) - -# set 'default_target' -set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target) +if (WITH_ADVANCED) + set_target_properties(flashattn PROPERTIES + OUTPUT_NAME libflashattn_advanced + PREFIX "" + ) +endif() +if (WITH_ADVANCED) + add_custom_target(build_whl + COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + DEPENDS flashattn + COMMENT "Running build wheel" + ) + + add_custom_target(default_target DEPENDS build_whl) + + set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target) +endif() diff --git a/csrc/setup.py b/csrc/setup.py index 11cbb4dabf2..d6e13600335 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -136,8 +136,9 @@ def get_package_version(): def get_data_files(): data_files = [] - source_lib_path = 'libflashattn.so' - data_files.append((".", [source_lib_path])) + #source_lib_path = 'libflashattn.so' + #data_files.append((".", [source_lib_path])) + data_files.append((".", ['flashattn_advanced.so'])) return data_files @@ -155,8 +156,6 @@ def run(self): # Determine the version numbers that will be used to determine the correct wheel # We're using the CUDA version used to build paddle, not the one currently installed # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) - paddle_cuda_version = "234" # parse(paddle.version.cuda) - paddle_version_raw = parse(paddle.__version__) python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" platform_name = get_platform() flash_version = get_package_version() From af386bf4e232948e1c713339e72eb2b05943272b Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Thu, 7 Dec 2023 16:40:07 +0800 Subject: [PATCH 27/37] update --- csrc/CMakeLists.txt | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index a6b9e6e0059..4c18659114a 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -137,6 +137,9 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + INSTALL(TARGETS flashattn LIBRARY DESTINATION "lib") From 3aca223f32b4117dc44177c9932eb44712fcae34 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Thu, 7 Dec 2023 18:17:05 +0800 Subject: [PATCH 28/37] Update --- csrc/CMakeLists.txt | 76 ++++++++++++++++++++++----------------------- csrc/setup.py | 16 ++++++---- 2 files changed, 48 insertions(+), 44 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 4c18659114a..9261d7d05cc 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -18,38 +18,38 @@ set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) set(FA2_SOURCES_CU flash_attn/src/cuda_utils.cu - #flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu - #flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu - #flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu - #flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu - #flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu + flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu + flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu + flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu + flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu ) add_library(flashattn SHARED @@ -63,12 +63,12 @@ target_include_directories(flashattn PRIVATE set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu flash_attn_with_bias_and_mask/src/cuda_utils.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu - #flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu - #flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim32.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim64.cu + flash_attn_with_bias_and_mask/src/fmha_fwd_with_mask_bias_hdim128.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim32.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu + flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu flash_attn_with_bias_and_mask/src/utils.cu) add_library(flashattn_with_bias_mask STATIC diff --git a/csrc/setup.py b/csrc/setup.py index d6e13600335..77c127e4633 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -138,7 +138,7 @@ def get_data_files(): data_files = [] #source_lib_path = 'libflashattn.so' #data_files.append((".", [source_lib_path])) - data_files.append((".", ['flashattn_advanced.so'])) + data_files.append((".", ['libflashattn_advanced.so'])) return data_files @@ -213,9 +213,13 @@ def run(self): }, python_requires=">=3.7", install_requires=[ - "paddle", - "einops", - "packaging", - "ninja", - ], + "common", + "dual", + "tight>=0.1.0", + "data", + "prox", + "ninja", # Put ninja before paddle if paddle depends on it + "einops", + "packaging", +], ) From 06edc27972b0a8a08a62b42ea410058679fc25bb Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Fri, 8 Dec 2023 10:23:43 +0800 Subject: [PATCH 29/37] update --- csrc/flash_attn/src/flash_bwd_launch_template.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index f3eb8850f0d..d611b5deaee 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -87,8 +87,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { - auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; - // auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; if (smem_size_dq_dk_dv >= 48 * 1024) { C10_CUDA_CHECK(cudaFuncSetAttribute( kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); From 45fcc53e73b3ea9b69098ccfe9878fd26467427d Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Fri, 8 Dec 2023 14:55:17 +0800 Subject: [PATCH 30/37] update --- csrc/CMakeLists.txt | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 9261d7d05cc..0d5e65ee369 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -88,14 +88,9 @@ target_link_libraries(flashattn flashattn_with_bias_mask) add_dependencies(flashattn flashattn_with_bias_mask) +option(NVCC_ARCH_BIN "Set default compute arch to 80" "80") -if (NOT DEFINED NVCC_ARCH_BIN) - message(FATAL_ERROR "NVCC_ARCH_BIN is not defined.") -endif() - -if (NVCC_ARCH_BIN STREQUAL "") - message(FATAL_ERROR "NVCC_ARCH_BIN is not set.") -endif() +message("NVCC_ARCH_BIN is set to: ${NVCC_ARCH_BIN}") STRING(REPLACE "-" ";" FA_NVCC_ARCH_BIN ${NVCC_ARCH_BIN}) From 940a8ae76bdc72fe79d38f9d083c7d629531c6c0 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Fri, 8 Dec 2023 15:12:50 +0800 Subject: [PATCH 31/37] default --- csrc/CMakeLists.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 0d5e65ee369..e52955036da 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -88,7 +88,7 @@ target_link_libraries(flashattn flashattn_with_bias_mask) add_dependencies(flashattn flashattn_with_bias_mask) -option(NVCC_ARCH_BIN "Set default compute arch to 80" "80") +set(NVCC_ARCH_BIN 80 CACHE STRING "CUDA architectures") message("NVCC_ARCH_BIN is set to: ${NVCC_ARCH_BIN}") From 6b6c7a88e490056fd47698b83402fb4c471ab109 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Fri, 8 Dec 2023 15:30:20 +0800 Subject: [PATCH 32/37] update --- csrc/CMakeLists.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index e52955036da..c8c1f142b89 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -1,5 +1,7 @@ cmake_minimum_required(VERSION 3.9 FATAL_ERROR) project(flash-attention LANGUAGES CXX CUDA) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package(Git QUIET REQUIRED) @@ -132,8 +134,6 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$) -set(CMAKE_CXX_STANDARD 17) -set(CMAKE_CXX_STANDARD_REQUIRED ON) INSTALL(TARGETS flashattn LIBRARY DESTINATION "lib") From 18ae75621289868c096ea63477271bd227ab5805 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Fri, 8 Dec 2023 17:04:01 +0800 Subject: [PATCH 33/37] update equal --- .../src/flash_fwd_launch_template.h | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 5090605cb56..ae707d0e9d4 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -64,19 +64,21 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { - // Will only return softmax if dropout, to reduce compilation time. - auto kernel = &flash_fwd_kernel; - // auto kernel = &flash_fwd_kernel; - if (smem_size >= 48 * 1024) { - C10_CUDA_CHECK(cudaFuncSetAttribute( - kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); - } - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); - // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); - kernel<<>>(params); - C10_CUDA_KERNEL_LAUNCH_CHECK(); + BOOL_SWITCH(is_equal_qk, Is_equal_seq_qk, [&] { + // Will only return softmax if dropout, to reduce compilation time. + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); }); }); }); From d926c09ca157063b1a8ed673d24fcb358c9a11d4 Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Sun, 10 Dec 2023 21:44:29 +0800 Subject: [PATCH 34/37] for so --- csrc/setup.py | 10 +++------- 1 file changed, 3 insertions(+), 7 deletions(-) diff --git a/csrc/setup.py b/csrc/setup.py index 77c127e4633..060268ccae6 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -182,13 +182,9 @@ class CustomInstallCommand(_install): def run(self): _install.run(self) install_path = self.install_lib - # src - source_lib_path = os.path.abspath('libflashattn.so') - - destination_lib_path = os.path.join(paddle_lib_path, 'libflashattn.so') - - # shutil.move(f"{source_lib_path}", f"{destination_lib_path}") - # os.symlink(source_lib_path, destination_lib_path) + source_lib_path = os.path.abspath('libflashattn_advanced.so') + destination_lib_path = os.path.join(paddle_lib_path, 'libflashattn_advanced.so') + shutil.copy(f"{source_lib_path}", f"{destination_lib_path}") setup( From a2714ebebc2113548bd5d47d6f44fde0b4cc4872 Mon Sep 17 00:00:00 2001 From: niuliling123 <51102941+niuliling123@users.noreply.github.com> Date: Mon, 11 Dec 2023 12:27:15 +0800 Subject: [PATCH 35/37] Update CMakeLists.txt --- csrc/CMakeLists.txt | 3 --- 1 file changed, 3 deletions(-) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index c8c1f142b89..9aa83b4fc7d 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -145,9 +145,6 @@ if (WITH_ADVANCED) OUTPUT_NAME libflashattn_advanced PREFIX "" ) -endif() - -if (WITH_ADVANCED) add_custom_target(build_whl COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel WORKING_DIRECTORY ${CMAKE_BINARY_DIR} From a61e35b2bb4195f98811ed90c06ddabf2390c75e Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Mon, 11 Dec 2023 15:18:05 +0800 Subject: [PATCH 36/37] update fa1 mask --- csrc/CMakeLists.txt | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 9aa83b4fc7d..a19b9274487 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -62,6 +62,7 @@ target_include_directories(flashattn PRIVATE flash_attn ${CUTLASS_3_DIR}/include) +if (WITH_ADVANCED) set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu flash_attn_with_bias_and_mask/src/cuda_utils.cu @@ -72,6 +73,12 @@ set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu flash_attn_with_bias_and_mask/src/utils.cu) +else() +set(FA1_SOURCES_CU + flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu + flash_attn_with_bias_and_mask/src/cuda_utils.cu + flash_attn_with_bias_and_mask/src/utils.cu) +endif() add_library(flashattn_with_bias_mask STATIC flash_attn_with_bias_and_mask/ From 600d748a16ebe044054c4cb695978476006baede Mon Sep 17 00:00:00 2001 From: niuliling123 Date: Thu, 14 Dec 2023 13:33:47 +0800 Subject: [PATCH 37/37] Update for fa extends --- csrc/flash_attn.cu | 591 +++++++++++++++++++++++++++++++++++++++++++++ csrc/setup.py | 282 +++++++-------------- 2 files changed, 675 insertions(+), 198 deletions(-) create mode 100644 csrc/flash_attn.cu diff --git a/csrc/flash_attn.cu b/csrc/flash_attn.cu new file mode 100644 index 00000000000..543b7d86d78 --- /dev/null +++ b/csrc/flash_attn.cu @@ -0,0 +1,591 @@ +#pragma once // NOLINT + +#include // NOLINT +#include // NOLINT + +#include +#include + +#include "paddle/extension.h" +#include "capi/flash_attn.h" +static std::pair GenerateRNGState( + const phi::GPUContext& ctx, + const paddle::optional& fixed_seed_offset, + const std::string& rng_name, + const int64_t batch_size, + const int64_t num_heads) { + if (fixed_seed_offset.get_ptr()) { + const int64_t* fixed_seed_offset_data = + fixed_seed_offset.get_ptr()->data(); + uint64_t seed = static_cast(fixed_seed_offset_data[0]); + uint64_t offset = static_cast(fixed_seed_offset_data[1]); + return std::make_pair(seed, offset); + } else { + uint64_t inc = batch_size * num_heads * 32; + std::pair seed_offset_pair; + // Error phi::Generator * gen = ctx.GetGenerator(); + // Error seed_offset_pair = gen->IncrementOffset(inc); + return seed_offset_pair; + } +} + +static std::vector GetAttnMaskDims(const paddle::Tensor* attn_mask) { + std::vector mask_dim_4d; + if (attn_mask) { + const auto& origin_dims = attn_mask->shape(); + auto rank = origin_dims.size(); + //#PADDLE_ENFORCE_GE( + //# rank, + //# 4, + //# phi::errors::InvalidArgument( + //# "The number of dimensions of attn_mask is expected to be greater " + //# "or equal to 4, but recieved %d. The shape of attn_mask is {%s}", + //# rank, + //# origin_dims)); + + int64_t first_dim = 1; + for (int i = 0; i < rank - 3; i++) { + first_dim *= origin_dims[i]; + } + mask_dim_4d = {first_dim, + origin_dims[rank - 3], + origin_dims[rank - 2], + origin_dims[rank - 1]}; + } + return mask_dim_4d; +} + +struct FlashAttnParamsBase { + int batch_size; + // for padded kernel, max_seqlen_q and seqlen_q is the same. + int64_t max_seqlen_q; + // for padded kernel, max_seqlen_k and seqlen_k is the same. + int64_t max_seqlen_k; + int num_heads; + int num_heads_k; + int head_size; + + int seqlen_q_rounded; + int seqlen_k_rounded; + int head_size_rounded; + + bool is_bf16; + float softmax_scale; + std::vector softmax_lse_dims; + + bool causal; + std::vector mask_dims; + const paddle::Tensor* attn_mask_tensor; + + FlashAttnParamsBase(const int _batch_size, + const int64_t _max_seqlen_q, + const int64_t _max_seqlen_k, + const int _num_heads, + const int _num_heads_k, + const int _head_size, + const float _scale, + const bool _causal, + const paddle::DataType q_dtype, + const paddle::optional& attn_mask) + : batch_size(_batch_size), + max_seqlen_q(_max_seqlen_q), + max_seqlen_k(_max_seqlen_k), + num_heads(_num_heads), + num_heads_k(_num_heads), + head_size(_head_size), + softmax_scale(_scale), + causal(_causal), + attn_mask_tensor(attn_mask.get_ptr()) { + is_bf16 = q_dtype == paddle::DataType::BFLOAT16; + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + head_size_rounded = round_multiple(head_size, 32); + seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + softmax_lse_dims = {batch_size, num_heads, seqlen_q_rounded}; + + if (attn_mask_tensor) { + //PADDLE_ENFORCE_NE(causal, + // true, + // phi::errors::InvalidArgument( + // "When attn_mask is set, causal can not be true.")); + + //PADDLE_ENFORCE_EQ( + // attn_mask->dtype(), + // q_dtype, + // phi::errors::InvalidArgument( + // "attn_mask is expected to have the same data type with q.")); + + mask_dims = GetAttnMaskDims(attn_mask_tensor); + } + } +}; + +template +struct FlashAttnFwdParamsV2 : public FlashAttnParamsBase { + float dropout; + bool return_softmax; + uint64_t seed; + uint64_t offset; + paddle::Tensor rng_state; + paddle::Tensor* softmax; + paddle::Tensor* softmax_lse; + paddle::Tensor* seed_offset; + + FlashAttnFwdParamsV2(const phi::GPUContext& ctx, + const int _batch_size, + const int64_t _max_seqlen_q, + const int64_t _max_seqlen_k, + const int _num_heads, + const int _num_heads_k, + const int _head_size, + const float _dropout, + const float _scale, + const bool _causal, + const bool _return_softmax, + const paddle::DataType q_dtype, + const bool is_test, + const std::string& rng_name, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + paddle::Tensor* _softmax, + paddle::Tensor* _softmax_lse, + paddle::Tensor* _seed_offset) + : FlashAttnParamsBase(_batch_size, + _max_seqlen_q, + _max_seqlen_k, + _num_heads, + _num_heads_k, + _head_size, + _scale, + _causal, + q_dtype, + attn_mask), + dropout(_dropout), + return_softmax(_return_softmax), + softmax(_softmax), + softmax_lse(_softmax_lse), + seed_offset(_seed_offset) { + dropout = is_test ? 0.0f : _dropout; + + // (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t + // with the same size. + rng_state = paddle::empty({2}, phi::CppTypeToDataType::Type()); + + auto seed_offset_pair = GenerateRNGState( + ctx, fixed_seed_offset, rng_name, batch_size, num_heads); + seed = seed_offset_pair.first; + offset = seed_offset_pair.second; + + seed_offset->reshape({2}); + int64_t seed_offset_data[2]; + seed_offset_data[0] = static_cast(seed); + seed_offset_data[1] = static_cast(offset); + //tensor.cc + softmax_lse->reshape(softmax_lse_dims); + // Error paddle::Tensor tp = paddle::empty(softmax_lse_dims, phi::CppTypeToDataType::Type()); + + if (return_softmax) { + //PADDLE_ENFORCE_EQ( + // dropout > 0.0f, + // true, + // phi::errors::InvalidArgument( + // "return_softmax is only supported when dropout > 0.0")); + + softmax->reshape( + {batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded}); + // Error ctx.template Alloc(softmax); + } + } +}; + +struct FlashAttnBwdParamsV2 : public FlashAttnParamsBase { + float dropout; + uint64_t seed; + uint64_t offset; + paddle::Tensor softmax_d; + paddle::Tensor dq_accum; + paddle::Tensor rng_state; + + FlashAttnBwdParamsV2(const phi::GPUContext& ctx, + const int _batch_size, + const int64_t _max_seqlen_q, + const int64_t _max_seqlen_k, + const int _num_heads, + const int _num_heads_k, + const int _head_size, + const float _dropout, + const float _scale, + const bool _causal, + const paddle::DataType q_dtype, + const paddle::optional& attn_mask, + const int64_t* seed_offset_data) + : FlashAttnParamsBase(_batch_size, + _max_seqlen_q, + _max_seqlen_k, + _num_heads, + _num_heads_k, + _head_size, + _scale, + _causal, + q_dtype, + attn_mask), + dropout(_dropout) { + seed = static_cast(seed_offset_data[0]); + offset = static_cast(seed_offset_data[1]); + + // (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t + // with the same size. + rng_state = paddle::empty({2}, phi::CppTypeToDataType::Type()); + + // gradient of softmax_lse + softmax_d = paddle::empty(softmax_lse_dims,phi::CppTypeToDataType::Type()); + + // an internal gradient of q, which will be further accumulated. + dq_accum = paddle::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded},phi::CppTypeToDataType::Type()); + } +}; + +static void CheckFlashAttnStatus(const bool status) { + // Error PADDLE_ENFORCE_EQ(status, + // Error true, + // Error phi::errors::External( + // Error "Error in Flash-Attention, detail information is: %s", + // Error phi::dynload::flash_attn_error())); +} + +static void RaiseNotSupportedError() { + // ErrorPADDLE_THROW( + // Error phi::errors::Unimplemented("FlashAttention is unsupported, please check " + // Error "the GPU compability and CUDA Version.")); +} + +template +void FlashAttnKernel(const Context& ctx, + const paddle::Tensor& q, + const paddle::Tensor& k, + const paddle::Tensor& v, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + paddle::Tensor* out, + paddle::Tensor* softmax, + paddle::Tensor* softmax_lse, + paddle::Tensor* seed_offset) { + // q, k, v [batch_size, seq_len, num_heads, head_dim] + const auto& dims = q.shape(); +//Error PADDLE_ENFORCE_EQ(dims.size(), +//Error 4, +//Error phi::errors::InvalidArgument( +//Error "flash_attn receive input with dim " +//Error "[batch_size, seq_len, num_heads, head_dim]")); +//Error + const int64_t batch_size = dims[0]; + const int64_t seqlen_q = dims[1]; + const int64_t num_heads = dims[2]; + const int64_t head_size = dims[3]; + const int64_t seqlen_k = k.shape()[1]; + const int64_t num_heads_k = k.shape()[2]; + + // TODO(umiswing): Add check shape + + const float softmax_scale = 1.0f / std::sqrt(head_size); + const float softmax_unscale = std::sqrt(head_size); + + FlashAttnFwdParamsV2 params = FlashAttnFwdParamsV2(ctx, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size, + dropout, + softmax_scale, + causal, + return_softmax, + q.dtype(), + is_test, + rng_name, + fixed_seed_offset, + attn_mask, + softmax, + softmax_lse, + seed_offset); + + //VLOG(10) << "[FlashAttn Forward] q.shape=[" << q.shape() << "], k.shape=[" + // << k.shape() << "], v.shape=[" << v.shape() << "]"; + //VLOG(10) << "[FlashAttn Forward] dropout=" << dropout + // << ", seed=" << params.seed << ", offset=" << params.offset; + //VLOG(10) << "[FlashAttn Forward] softmax_scale=" << softmax_scale + // << ", softmax_unscale=" << softmax_unscale; + //if (attn_mask.get_ptr()) { + // VLOG(10) << "[FlashAttn Forward] attn_mask.shape=[" + // << (attn_mask.get_ptr())->shape() << "]"; + //} + + //Error ctx.template Alloc(out); + + cudaStream_t stream = q.stream(); + + bool succ = flash_attn_fwd( + q.data(), + k.data(), + v.data(), + params.rng_state.data(), + out->data(), + params.return_softmax ? params.softmax->data() : nullptr, + params.softmax_lse->data(), + params.batch_size, + params.max_seqlen_q, + params.max_seqlen_k, + params.seqlen_q_rounded, + params.seqlen_k_rounded, + params.num_heads, + params.num_heads_k, + params.head_size, + params.head_size_rounded, + params.dropout, + params.softmax_scale, + softmax_unscale, + params.causal, + params.return_softmax, + params.is_bf16, + stream, + params.seed, + params.offset, + params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr, + params.mask_dims.data()); + CheckFlashAttnStatus(succ); +} +std::vector FaFwd( + const paddle::Tensor& q, + const paddle::Tensor& k, + const paddle::Tensor& v, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name) { + return_softmax = false; + paddle::Tensor out = paddle::empty(q.shape(), q.type()); + // out.set_layout(q.layout()); + paddle::Tensor softmax = paddle::empty({1}, q.type()); + paddle::Tensor softmax_lse = paddle::empty({1}, q.type()); + paddle::Tensor seed_offset = paddle::empty({1}, q.type()); + auto place = q.place(); + const phi::GPUContext *ctx{nullptr}; + //Error auto ctx = phi::GPUContext(); + //Error auto ctx = new phi::GPUContext(place); + switch(q.type()){ + case paddle::DataType::FLOAT16: + FlashAttnKernel(*ctx,q,k,v, fixed_seed_offset, attn_mask, dropout, causal, return_softmax, is_test, rng_name, &out, &softmax, &softmax_lse, &seed_offset); + break; + case paddle::DataType::BFLOAT16: + FlashAttnKernel(*ctx,q,k,v, fixed_seed_offset, attn_mask, dropout, causal, return_softmax, is_test, rng_name, &out, &softmax, &softmax_lse, &seed_offset); + break; + default: + break; + // Error + } + return {out, softmax, softmax_lse, seed_offset}; +} + +std::vector> FaFwdInferShape( + std::vector q_shape, + std::vector k_shape, + std::vector v_shape, + std::vector fixed_seed_offset, + std::vector mask_shape, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name) { + return {q_shape, k_shape, v_shape, mask_shape}; +} + +PD_BUILD_OP(flash_attn_with_mask) + .Inputs({"q", "k", "v", "fixed_seed_offset","attn_mask"}) + .Outputs({"out", "softmax", "softmax_lse","seed_offset"}) + .Attrs({"dropout: float","causal:bool", "return_softmax:bool","is_test:bool","rng_name:std::string"}) + .SetKernelFn(PD_KERNEL(FaFwd)) + .SetInferShapeFn(PD_INFER_SHAPE(FaFwdInferShape)); + +template +void FlashAttnGradKernel(const Context& ctx, + const paddle::Tensor& q, + const paddle::Tensor& k, + const paddle::Tensor& v, + const paddle::Tensor& out, + const paddle::Tensor& softmax_lse, + const paddle::Tensor& seed_offset, + const paddle::optional& attn_mask, + const paddle::Tensor& dout, + float dropout, + bool causal, + paddle::Tensor* dq, + paddle::Tensor* dk, + paddle::Tensor* dv) { + void* dq_ptr = nullptr; + void* dk_ptr = nullptr; + void* dv_ptr = nullptr; + + // Error ctx.template Alloc(dq); + dq_ptr = dq->data(); + + paddle::Tensor dk_tmp; + dk_tmp = paddle::empty_like(k, q.type()); + dk_ptr = dk_tmp.data(); + + paddle::Tensor dv_tmp; + dv_tmp = paddle::empty_like(v, q.type()); + dv_ptr = dv_tmp.data(); + + const cudaStream_t stream = q.stream(); + + // q, k, v [batch_size, seq_len, num_heads, head_dim] + const auto& dims = q.shape(); + + const int64_t batch_size = dims[0]; + const int64_t seqlen_q = dims[1]; + const int64_t num_heads = dims[2]; + const int64_t head_size_og = dout.shape()[3]; + const int64_t head_size = dims[3]; + const int64_t seqlen_k = k.shape()[1]; + const int64_t num_heads_k = k.shape()[2]; + + // TODO(umiswing): add shape check + // Error PADDLE_ENFORCE_EQ( + // Error head_size_og, + // Error head_size, + // Error phi::errors::InvalidArgument( + // Error "flash_attn_bwd receive input with head_size_og == head_size")); + + const float softmax_scale = 1.0f / std::sqrt(head_size); + const float softmax_unscale = std::sqrt(head_size); + + FlashAttnBwdParamsV2 params = + FlashAttnBwdParamsV2(ctx, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size, + dropout, + softmax_scale, + causal, + q.dtype(), + attn_mask, + seed_offset.data()); + + // Error VLOG(10) << "[FlashAttn Forward] q.shape=[" << q.shape() << "], k.shape=[" + // Error << k.shape() << "], v.shape=[" << v.shape() << "]"; + // Error VLOG(10) << "[FlashAttn Forward] dropout=" << dropout + // Error << ", seed=" << params.seed << ", offset=" << params.offset; + // Error VLOG(10) << "[FlashAttn Forward] softmax_scale=" << softmax_scale + // Error << ", softmax_unscale=" << softmax_unscale; + // Error if (attn_mask.get_ptr()) { + // Error VLOG(10) << "[FlashAttn Backward] attn_mask.shape=[" + // Error << (attn_mask.get_ptr())->shape() << "]"; + // Error } +#ifdef PADDLE_WITH_ADVANCED + int num_splits = 1; // Error get_num_split(); +#else + int num_splits = 0; // Error get_num_split(); +#endif + + bool succ = flash_attn_bwd( + dout.data(), + q.data(), + k.data(), + v.data(), + out.data(), + params.softmax_d.data(), + softmax_lse.data(), + params.rng_state.data(), + dq_ptr, + dk_ptr, + dv_ptr, + params.dq_accum.data(), + params.batch_size, + params.max_seqlen_q, + params.max_seqlen_k, + params.seqlen_q_rounded, + params.seqlen_k_rounded, + params.num_heads, + params.num_heads_k, + params.head_size, + params.head_size_rounded, + params.dropout, + params.softmax_scale, + softmax_unscale, + params.causal, + params.is_bf16, + num_splits, + stream, + params.seed, + params.offset, + params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr, + params.attn_mask_tensor ? params.mask_dims.data() : nullptr); + CheckFlashAttnStatus(succ); +} + +std::vector FaBwd( + const paddle::Tensor& q, + const paddle::Tensor& k, + const paddle::Tensor& v, + const paddle::Tensor& out, + const paddle::Tensor& softmax_lse, + const paddle::Tensor& seed_offset, + const paddle::optional& attn_mask, + const paddle::Tensor& dout, + float dropout, + bool causal) { + + paddle::Tensor dq = paddle::empty(q.shape(), q.type()); + paddle::Tensor dk = paddle::empty(q.shape(), q.type()); + paddle::Tensor dv = paddle::empty(q.shape(), q.type()); + const phi::GPUContext *ctx{nullptr}; + //Error auto ctx = phi::GPUContext(); + switch(q.type()){ + case paddle::DataType::FLOAT16: + FlashAttnGradKernel(*ctx,q,k,v,out, softmax_lse, seed_offset, attn_mask, dout, dropout, causal,&dq, &dk, &dv); + break; + case paddle::DataType::BFLOAT16: + FlashAttnGradKernel(*ctx,q,k,v,out, softmax_lse, seed_offset, attn_mask, dout, dropout, causal, &dq, &dk, &dv); + break; + default: + break; + // Error + } + return {dq, dk, dv}; +} + +std::vector> FaBwdInferShape( + std::vector q_shape, + std::vector k_shape, + std::vector v_shape, + std::vector out_shape, + std::vector softmax_lse, + std::vector seed_offset, + std::vector mask_shape, + std::vector dout, + float dropout, + bool causal) { + return {q_shape, k_shape, v_shape, out_shape, softmax_lse,seed_offset, mask_shape,dout}; +} + +PD_BUILD_OP(flash_attn_with_mask_grad) + .Inputs({"q", "k", "v", "out", "softmax_lse","seed_offset", "attn_mask","dout"}) + .Outputs({"dq", "dk", "dv"}) + .Attrs({"dropout: float","causal:bool"}) + .SetKernelFn(PD_KERNEL(FaBwd)) + .SetInferShapeFn(PD_INFER_SHAPE(FaBwdInferShape)); diff --git a/csrc/setup.py b/csrc/setup.py index 060268ccae6..4c22485f83c 100644 --- a/csrc/setup.py +++ b/csrc/setup.py @@ -1,6 +1,4 @@ -# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); +#licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # @@ -11,211 +9,99 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -# Adapted from https://github.com/NVIDIA/apex/blob/master/setup.py -import ast -import logging -import os -import platform -import re -import shutil -import subprocess -import sys -import warnings +import multiprocessing from pathlib import Path +import os -from packaging.version import parse -from setuptools import find_packages, setup -from setuptools.command.install import install as _install -from wheel.bdist_wheel import bdist_wheel as _bdist_wheel - -import paddle -from paddle.utils.cpp_extension.extension_utils import find_cuda_home - -version_detail = sys.version_info -python_version = platform.python_version() -version = version_detail[0] + version_detail[1] / 10 -env_version = os.getenv("PY_VERSION") - -if version < 3.7: - raise RuntimeError( - f"Paddle only supports Python version >= 3.7 now," - f"you are using Python {python_version}" - ) -elif env_version is None: - print(f"export PY_VERSION = { python_version }") - os.environ["PY_VERSION"] = python_version - -elif env_version != version: - warnings.warn( - f"You set PY_VERSION={env_version}, but" - f"your current python environment is {version}" - f"we will use your current python version to execute" - ) - os.environ["PY_VERSION"] = python_version +def get_gencode_flags(): + import paddle -paddle_include_path = paddle.sysconfig.get_include() -paddle_lib_path = paddle.sysconfig.get_lib() + prop = paddle.device.cuda.get_device_properties() + cc = prop.major * 10 + prop.minor + return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)] -print("Paddle Include Path:", paddle_include_path) -print("Paddle Lib Path:", paddle_lib_path) -# preparing parameters for setup() -paddle_version = paddle.version.full_version -cuda_version = paddle.version.cuda_version +def run(func): + p = multiprocessing.Process(target=func) + p.start() + p.join() -with open("../../README.md", "r", encoding="utf-8") as fh: - long_description = fh.read() +def change_pwd(): + path = os.path.dirname(__file__) + if path: + os.chdir(path) -# ninja build does not work unless include_dirs are abs path this_dir = os.path.dirname(os.path.abspath(__file__)) -CUDA_HOME = find_cuda_home() -PACKAGE_NAME = "paddle_flash_attn" - - -def get_platform(): - """ - Returns the platform name as used in wheel filenames. - """ - if sys.platform.startswith('linux'): - return 'linux_x86_64' - elif sys.platform == 'darwin': - mac_version = '.'.join(platform.mac_ver()[0].split('.')[:2]) - return f'macosx_{mac_version}_x86_64' - elif sys.platform == 'win32': - return 'win_amd64' - else: - raise ValueError(f'Unsupported platform: {sys.platform}') - - -def get_cuda_bare_metal_version(cuda_dir): - raw_output = subprocess.check_output( - [cuda_dir + "/bin/nvcc", "-V"], universal_newlines=True +def setup_fused_ln(): + from paddle.utils.cpp_extension import CUDAExtension, setup + + gencode_flags = get_gencode_flags() + change_pwd() + setup( + name="flash_attn", + ext_modules=CUDAExtension( + sources=[ + "flash_attn.cu", + "flash_attn/src/cuda_utils.cu", + "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"] + gencode_flags, + "nvcc": [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v", + # "--ptxas-options=-O2", + "-lineinfo" + ] + }, + include_dirs=[ + "flash_attn", + "flash_attn/src", + "cutlass/include", + ], + ), ) - output = raw_output.split() - release_idx = output.index("release") + 1 - bare_metal_version = parse(output[release_idx].split(",")[0]) - - return raw_output, bare_metal_version - - -def _is_cuda_available(): - """ - Check whether CUDA is available. - """ - try: - assert len(paddle.static.cuda_places()) > 0 - return True - except Exception as e: - logging.warning( - "You are using GPU version PaddlePaddle, but there is no GPU " - "detected on your machine. Maybe CUDA devices is not set properly." - f"\n Original Error is {e}" - ) - return False - - -check = _is_cuda_available() -cmdclass = {} - - -def get_package_version(): - with open(Path(this_dir) / "../flash_attn" / "__init__.py", "r") as f: - version_match = re.search( - r"^__version__\s*=\s*(.*)$", f.read(), re.MULTILINE - ) - public_version = ast.literal_eval(version_match.group(1)) - local_version = os.environ.get("FLASH_ATTN_LOCAL_VERSION") - if local_version: - return f"{public_version}+{local_version}" - else: - return str(public_version) - - -def get_data_files(): - data_files = [] - #source_lib_path = 'libflashattn.so' - #data_files.append((".", [source_lib_path])) - data_files.append((".", ['libflashattn_advanced.so'])) - return data_files - - -class CustomWheelsCommand(_bdist_wheel): - """ - The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot - find an existing wheel (which is currently the case for all flash attention installs). We use - the environment parameters to detect whether there is already a pre-built version of a compatible - wheel available and short-circuits the standard full build pipeline. - """ - - def run(self): - self.run_command('build_ext') - super().run() - # Determine the version numbers that will be used to determine the correct wheel - # We're using the CUDA version used to build paddle, not the one currently installed - # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) - python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" - platform_name = get_platform() - flash_version = get_package_version() - cxx11_abi = "" # str(paddle._C.-D_GLIBCXX_USE_CXX11_ABI).upper() - - # Determine wheel URL based on CUDA version, paddle version, python version and OS - wheel_filename = f'{PACKAGE_NAME}-{flash_version}-cu{cuda_version}-paddle{paddle_version}-{python_version}-{python_version}-{platform_name}.whl' - impl_tag, abi_tag, plat_tag = self.get_tag() - original_wheel_name = ( - f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" - ) - - # new_wheel_name = wheel_filename - new_wheel_name = ( - f"{self.wheel_dist_name}-{python_version}-{abi_tag}-{plat_tag}" - ) - shutil.move( - f"{self.dist_dir}/{original_wheel_name}.whl", - f"{self.dist_dir}/{new_wheel_name}.whl", - ) - - -class CustomInstallCommand(_install): - def run(self): - _install.run(self) - install_path = self.install_lib - source_lib_path = os.path.abspath('libflashattn_advanced.so') - destination_lib_path = os.path.join(paddle_lib_path, 'libflashattn_advanced.so') - shutil.copy(f"{source_lib_path}", f"{destination_lib_path}") -setup( - name=PACKAGE_NAME, - version=get_package_version(), - packages=find_packages(), - data_files=get_data_files(), - package_data={PACKAGE_NAME: ['build/libflashattn.so']}, - author_email="Paddle-better@baidu.com", - description="Flash Attention: Fast and Memory-Efficient Exact Attention", - long_description=long_description, - long_description_content_type="text/markdown", - url="https://github.com/PaddlePaddle/flash-attention", - classifiers=[ - "Programming Language :: Python :: 37", - "License :: OSI Approved :: BSD License", - "Operating System :: Unix", - ], - cmdclass={ - 'bdist_wheel': CustomWheelsCommand, - 'install': CustomInstallCommand, - }, - python_requires=">=3.7", - install_requires=[ - "common", - "dual", - "tight>=0.1.0", - "data", - "prox", - "ninja", # Put ninja before paddle if paddle depends on it - "einops", - "packaging", -], -) +run(setup_fused_ln)