diff --git a/csrc/CMakeLists.txt b/csrc/CMakeLists.txt index 5b305fa98c0..a19b9274487 100644 --- a/csrc/CMakeLists.txt +++ b/csrc/CMakeLists.txt @@ -1,5 +1,7 @@ cmake_minimum_required(VERSION 3.9 FATAL_ERROR) project(flash-attention LANGUAGES CXX CUDA) +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CXX_STANDARD_REQUIRED ON) find_package(Git QUIET REQUIRED) @@ -7,6 +9,11 @@ execute_process(COMMAND ${GIT_EXECUTABLE} submodule update --init --recursive WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR} RESULT_VARIABLE GIT_SUBMOD_RESULT) +#cmake -DWITH_ADVANCED=ON +if (WITH_ADVANCED) + add_compile_definitions(PADDLE_WITH_ADVANCED) +endif() + add_definitions("-DFLASH_ATTN_WITH_TORCH=0") set(CUTLASS_3_DIR ${CMAKE_CURRENT_SOURCE_DIR}/cutlass) @@ -55,6 +62,7 @@ target_include_directories(flashattn PRIVATE flash_attn ${CUTLASS_3_DIR}/include) +if (WITH_ADVANCED) set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu flash_attn_with_bias_and_mask/src/cuda_utils.cu @@ -65,6 +73,12 @@ set(FA1_SOURCES_CU flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim64.cu flash_attn_with_bias_and_mask/src/fmha_bwd_with_mask_bias_hdim128.cu flash_attn_with_bias_and_mask/src/utils.cu) +else() +set(FA1_SOURCES_CU + flash_attn_with_bias_and_mask/flash_attn_with_bias_mask.cu + flash_attn_with_bias_and_mask/src/cuda_utils.cu + flash_attn_with_bias_and_mask/src/utils.cu) +endif() add_library(flashattn_with_bias_mask STATIC flash_attn_with_bias_and_mask/ @@ -83,18 +97,14 @@ target_link_libraries(flashattn flashattn_with_bias_mask) add_dependencies(flashattn flashattn_with_bias_mask) +set(NVCC_ARCH_BIN 80 CACHE STRING "CUDA architectures") -if (NOT DEFINED NVCC_ARCH_BIN) - message(FATAL_ERROR "NVCC_ARCH_BIN is not defined.") -endif() - -if (NVCC_ARCH_BIN STREQUAL "") - message(FATAL_ERROR "NVCC_ARCH_BIN is not set.") -endif() +message("NVCC_ARCH_BIN is set to: ${NVCC_ARCH_BIN}") STRING(REPLACE "-" ";" FA_NVCC_ARCH_BIN ${NVCC_ARCH_BIN}) set(FA_GENCODE_OPTION "SHELL:") + foreach(arch ${FA_NVCC_ARCH_BIN}) if(${arch} GREATER_EQUAL 80) set(FA_GENCODE_OPTION "${FA_GENCODE_OPTION} -gencode arch=compute_${arch},code=sm_${arch}") @@ -131,7 +141,25 @@ target_compile_options(flashattn_with_bias_mask PRIVATE $<$) + INSTALL(TARGETS flashattn LIBRARY DESTINATION "lib") INSTALL(FILES capi/flash_attn.h DESTINATION "include") + +if (WITH_ADVANCED) + set_target_properties(flashattn PROPERTIES + OUTPUT_NAME libflashattn_advanced + PREFIX "" + ) + add_custom_target(build_whl + COMMAND ${CMAKE_COMMAND} -E env python ${CMAKE_SOURCE_DIR}/setup.py bdist_wheel + WORKING_DIRECTORY ${CMAKE_BINARY_DIR} + DEPENDS flashattn + COMMENT "Running build wheel" + ) + + add_custom_target(default_target DEPENDS build_whl) + + set_property(DIRECTORY PROPERTY DEFAULT_TARGET default_target) +endif() diff --git a/csrc/flash_attn.cu b/csrc/flash_attn.cu new file mode 100644 index 00000000000..543b7d86d78 --- /dev/null +++ b/csrc/flash_attn.cu @@ -0,0 +1,591 @@ +#pragma once // NOLINT + +#include // NOLINT +#include // NOLINT + +#include +#include + +#include "paddle/extension.h" +#include "capi/flash_attn.h" +static std::pair GenerateRNGState( + const phi::GPUContext& ctx, + const paddle::optional& fixed_seed_offset, + const std::string& rng_name, + const int64_t batch_size, + const int64_t num_heads) { + if (fixed_seed_offset.get_ptr()) { + const int64_t* fixed_seed_offset_data = + fixed_seed_offset.get_ptr()->data(); + uint64_t seed = static_cast(fixed_seed_offset_data[0]); + uint64_t offset = static_cast(fixed_seed_offset_data[1]); + return std::make_pair(seed, offset); + } else { + uint64_t inc = batch_size * num_heads * 32; + std::pair seed_offset_pair; + // Error phi::Generator * gen = ctx.GetGenerator(); + // Error seed_offset_pair = gen->IncrementOffset(inc); + return seed_offset_pair; + } +} + +static std::vector GetAttnMaskDims(const paddle::Tensor* attn_mask) { + std::vector mask_dim_4d; + if (attn_mask) { + const auto& origin_dims = attn_mask->shape(); + auto rank = origin_dims.size(); + //#PADDLE_ENFORCE_GE( + //# rank, + //# 4, + //# phi::errors::InvalidArgument( + //# "The number of dimensions of attn_mask is expected to be greater " + //# "or equal to 4, but recieved %d. The shape of attn_mask is {%s}", + //# rank, + //# origin_dims)); + + int64_t first_dim = 1; + for (int i = 0; i < rank - 3; i++) { + first_dim *= origin_dims[i]; + } + mask_dim_4d = {first_dim, + origin_dims[rank - 3], + origin_dims[rank - 2], + origin_dims[rank - 1]}; + } + return mask_dim_4d; +} + +struct FlashAttnParamsBase { + int batch_size; + // for padded kernel, max_seqlen_q and seqlen_q is the same. + int64_t max_seqlen_q; + // for padded kernel, max_seqlen_k and seqlen_k is the same. + int64_t max_seqlen_k; + int num_heads; + int num_heads_k; + int head_size; + + int seqlen_q_rounded; + int seqlen_k_rounded; + int head_size_rounded; + + bool is_bf16; + float softmax_scale; + std::vector softmax_lse_dims; + + bool causal; + std::vector mask_dims; + const paddle::Tensor* attn_mask_tensor; + + FlashAttnParamsBase(const int _batch_size, + const int64_t _max_seqlen_q, + const int64_t _max_seqlen_k, + const int _num_heads, + const int _num_heads_k, + const int _head_size, + const float _scale, + const bool _causal, + const paddle::DataType q_dtype, + const paddle::optional& attn_mask) + : batch_size(_batch_size), + max_seqlen_q(_max_seqlen_q), + max_seqlen_k(_max_seqlen_k), + num_heads(_num_heads), + num_heads_k(_num_heads), + head_size(_head_size), + softmax_scale(_scale), + causal(_causal), + attn_mask_tensor(attn_mask.get_ptr()) { + is_bf16 = q_dtype == paddle::DataType::BFLOAT16; + + auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; }; + head_size_rounded = round_multiple(head_size, 32); + seqlen_q_rounded = round_multiple(max_seqlen_q, 128); + seqlen_k_rounded = round_multiple(max_seqlen_k, 128); + + softmax_lse_dims = {batch_size, num_heads, seqlen_q_rounded}; + + if (attn_mask_tensor) { + //PADDLE_ENFORCE_NE(causal, + // true, + // phi::errors::InvalidArgument( + // "When attn_mask is set, causal can not be true.")); + + //PADDLE_ENFORCE_EQ( + // attn_mask->dtype(), + // q_dtype, + // phi::errors::InvalidArgument( + // "attn_mask is expected to have the same data type with q.")); + + mask_dims = GetAttnMaskDims(attn_mask_tensor); + } + } +}; + +template +struct FlashAttnFwdParamsV2 : public FlashAttnParamsBase { + float dropout; + bool return_softmax; + uint64_t seed; + uint64_t offset; + paddle::Tensor rng_state; + paddle::Tensor* softmax; + paddle::Tensor* softmax_lse; + paddle::Tensor* seed_offset; + + FlashAttnFwdParamsV2(const phi::GPUContext& ctx, + const int _batch_size, + const int64_t _max_seqlen_q, + const int64_t _max_seqlen_k, + const int _num_heads, + const int _num_heads_k, + const int _head_size, + const float _dropout, + const float _scale, + const bool _causal, + const bool _return_softmax, + const paddle::DataType q_dtype, + const bool is_test, + const std::string& rng_name, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + paddle::Tensor* _softmax, + paddle::Tensor* _softmax_lse, + paddle::Tensor* _seed_offset) + : FlashAttnParamsBase(_batch_size, + _max_seqlen_q, + _max_seqlen_k, + _num_heads, + _num_heads_k, + _head_size, + _scale, + _causal, + q_dtype, + attn_mask), + dropout(_dropout), + return_softmax(_return_softmax), + softmax(_softmax), + softmax_lse(_softmax_lse), + seed_offset(_seed_offset) { + dropout = is_test ? 0.0f : _dropout; + + // (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t + // with the same size. + rng_state = paddle::empty({2}, phi::CppTypeToDataType::Type()); + + auto seed_offset_pair = GenerateRNGState( + ctx, fixed_seed_offset, rng_name, batch_size, num_heads); + seed = seed_offset_pair.first; + offset = seed_offset_pair.second; + + seed_offset->reshape({2}); + int64_t seed_offset_data[2]; + seed_offset_data[0] = static_cast(seed); + seed_offset_data[1] = static_cast(offset); + //tensor.cc + softmax_lse->reshape(softmax_lse_dims); + // Error paddle::Tensor tp = paddle::empty(softmax_lse_dims, phi::CppTypeToDataType::Type()); + + if (return_softmax) { + //PADDLE_ENFORCE_EQ( + // dropout > 0.0f, + // true, + // phi::errors::InvalidArgument( + // "return_softmax is only supported when dropout > 0.0")); + + softmax->reshape( + {batch_size, num_heads, seqlen_q_rounded, seqlen_k_rounded}); + // Error ctx.template Alloc(softmax); + } + } +}; + +struct FlashAttnBwdParamsV2 : public FlashAttnParamsBase { + float dropout; + uint64_t seed; + uint64_t offset; + paddle::Tensor softmax_d; + paddle::Tensor dq_accum; + paddle::Tensor rng_state; + + FlashAttnBwdParamsV2(const phi::GPUContext& ctx, + const int _batch_size, + const int64_t _max_seqlen_q, + const int64_t _max_seqlen_k, + const int _num_heads, + const int _num_heads_k, + const int _head_size, + const float _dropout, + const float _scale, + const bool _causal, + const paddle::DataType q_dtype, + const paddle::optional& attn_mask, + const int64_t* seed_offset_data) + : FlashAttnParamsBase(_batch_size, + _max_seqlen_q, + _max_seqlen_k, + _num_heads, + _num_heads_k, + _head_size, + _scale, + _causal, + q_dtype, + attn_mask), + dropout(_dropout) { + seed = static_cast(seed_offset_data[0]); + offset = static_cast(seed_offset_data[1]); + + // (umiswing): There is no suitable kernel for uint64_t, allocate in int64_t + // with the same size. + rng_state = paddle::empty({2}, phi::CppTypeToDataType::Type()); + + // gradient of softmax_lse + softmax_d = paddle::empty(softmax_lse_dims,phi::CppTypeToDataType::Type()); + + // an internal gradient of q, which will be further accumulated. + dq_accum = paddle::empty({batch_size, num_heads, seqlen_q_rounded, head_size_rounded},phi::CppTypeToDataType::Type()); + } +}; + +static void CheckFlashAttnStatus(const bool status) { + // Error PADDLE_ENFORCE_EQ(status, + // Error true, + // Error phi::errors::External( + // Error "Error in Flash-Attention, detail information is: %s", + // Error phi::dynload::flash_attn_error())); +} + +static void RaiseNotSupportedError() { + // ErrorPADDLE_THROW( + // Error phi::errors::Unimplemented("FlashAttention is unsupported, please check " + // Error "the GPU compability and CUDA Version.")); +} + +template +void FlashAttnKernel(const Context& ctx, + const paddle::Tensor& q, + const paddle::Tensor& k, + const paddle::Tensor& v, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name, + paddle::Tensor* out, + paddle::Tensor* softmax, + paddle::Tensor* softmax_lse, + paddle::Tensor* seed_offset) { + // q, k, v [batch_size, seq_len, num_heads, head_dim] + const auto& dims = q.shape(); +//Error PADDLE_ENFORCE_EQ(dims.size(), +//Error 4, +//Error phi::errors::InvalidArgument( +//Error "flash_attn receive input with dim " +//Error "[batch_size, seq_len, num_heads, head_dim]")); +//Error + const int64_t batch_size = dims[0]; + const int64_t seqlen_q = dims[1]; + const int64_t num_heads = dims[2]; + const int64_t head_size = dims[3]; + const int64_t seqlen_k = k.shape()[1]; + const int64_t num_heads_k = k.shape()[2]; + + // TODO(umiswing): Add check shape + + const float softmax_scale = 1.0f / std::sqrt(head_size); + const float softmax_unscale = std::sqrt(head_size); + + FlashAttnFwdParamsV2 params = FlashAttnFwdParamsV2(ctx, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size, + dropout, + softmax_scale, + causal, + return_softmax, + q.dtype(), + is_test, + rng_name, + fixed_seed_offset, + attn_mask, + softmax, + softmax_lse, + seed_offset); + + //VLOG(10) << "[FlashAttn Forward] q.shape=[" << q.shape() << "], k.shape=[" + // << k.shape() << "], v.shape=[" << v.shape() << "]"; + //VLOG(10) << "[FlashAttn Forward] dropout=" << dropout + // << ", seed=" << params.seed << ", offset=" << params.offset; + //VLOG(10) << "[FlashAttn Forward] softmax_scale=" << softmax_scale + // << ", softmax_unscale=" << softmax_unscale; + //if (attn_mask.get_ptr()) { + // VLOG(10) << "[FlashAttn Forward] attn_mask.shape=[" + // << (attn_mask.get_ptr())->shape() << "]"; + //} + + //Error ctx.template Alloc(out); + + cudaStream_t stream = q.stream(); + + bool succ = flash_attn_fwd( + q.data(), + k.data(), + v.data(), + params.rng_state.data(), + out->data(), + params.return_softmax ? params.softmax->data() : nullptr, + params.softmax_lse->data(), + params.batch_size, + params.max_seqlen_q, + params.max_seqlen_k, + params.seqlen_q_rounded, + params.seqlen_k_rounded, + params.num_heads, + params.num_heads_k, + params.head_size, + params.head_size_rounded, + params.dropout, + params.softmax_scale, + softmax_unscale, + params.causal, + params.return_softmax, + params.is_bf16, + stream, + params.seed, + params.offset, + params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr, + params.mask_dims.data()); + CheckFlashAttnStatus(succ); +} +std::vector FaFwd( + const paddle::Tensor& q, + const paddle::Tensor& k, + const paddle::Tensor& v, + const paddle::optional& fixed_seed_offset, + const paddle::optional& attn_mask, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name) { + return_softmax = false; + paddle::Tensor out = paddle::empty(q.shape(), q.type()); + // out.set_layout(q.layout()); + paddle::Tensor softmax = paddle::empty({1}, q.type()); + paddle::Tensor softmax_lse = paddle::empty({1}, q.type()); + paddle::Tensor seed_offset = paddle::empty({1}, q.type()); + auto place = q.place(); + const phi::GPUContext *ctx{nullptr}; + //Error auto ctx = phi::GPUContext(); + //Error auto ctx = new phi::GPUContext(place); + switch(q.type()){ + case paddle::DataType::FLOAT16: + FlashAttnKernel(*ctx,q,k,v, fixed_seed_offset, attn_mask, dropout, causal, return_softmax, is_test, rng_name, &out, &softmax, &softmax_lse, &seed_offset); + break; + case paddle::DataType::BFLOAT16: + FlashAttnKernel(*ctx,q,k,v, fixed_seed_offset, attn_mask, dropout, causal, return_softmax, is_test, rng_name, &out, &softmax, &softmax_lse, &seed_offset); + break; + default: + break; + // Error + } + return {out, softmax, softmax_lse, seed_offset}; +} + +std::vector> FaFwdInferShape( + std::vector q_shape, + std::vector k_shape, + std::vector v_shape, + std::vector fixed_seed_offset, + std::vector mask_shape, + float dropout, + bool causal, + bool return_softmax, + bool is_test, + const std::string& rng_name) { + return {q_shape, k_shape, v_shape, mask_shape}; +} + +PD_BUILD_OP(flash_attn_with_mask) + .Inputs({"q", "k", "v", "fixed_seed_offset","attn_mask"}) + .Outputs({"out", "softmax", "softmax_lse","seed_offset"}) + .Attrs({"dropout: float","causal:bool", "return_softmax:bool","is_test:bool","rng_name:std::string"}) + .SetKernelFn(PD_KERNEL(FaFwd)) + .SetInferShapeFn(PD_INFER_SHAPE(FaFwdInferShape)); + +template +void FlashAttnGradKernel(const Context& ctx, + const paddle::Tensor& q, + const paddle::Tensor& k, + const paddle::Tensor& v, + const paddle::Tensor& out, + const paddle::Tensor& softmax_lse, + const paddle::Tensor& seed_offset, + const paddle::optional& attn_mask, + const paddle::Tensor& dout, + float dropout, + bool causal, + paddle::Tensor* dq, + paddle::Tensor* dk, + paddle::Tensor* dv) { + void* dq_ptr = nullptr; + void* dk_ptr = nullptr; + void* dv_ptr = nullptr; + + // Error ctx.template Alloc(dq); + dq_ptr = dq->data(); + + paddle::Tensor dk_tmp; + dk_tmp = paddle::empty_like(k, q.type()); + dk_ptr = dk_tmp.data(); + + paddle::Tensor dv_tmp; + dv_tmp = paddle::empty_like(v, q.type()); + dv_ptr = dv_tmp.data(); + + const cudaStream_t stream = q.stream(); + + // q, k, v [batch_size, seq_len, num_heads, head_dim] + const auto& dims = q.shape(); + + const int64_t batch_size = dims[0]; + const int64_t seqlen_q = dims[1]; + const int64_t num_heads = dims[2]; + const int64_t head_size_og = dout.shape()[3]; + const int64_t head_size = dims[3]; + const int64_t seqlen_k = k.shape()[1]; + const int64_t num_heads_k = k.shape()[2]; + + // TODO(umiswing): add shape check + // Error PADDLE_ENFORCE_EQ( + // Error head_size_og, + // Error head_size, + // Error phi::errors::InvalidArgument( + // Error "flash_attn_bwd receive input with head_size_og == head_size")); + + const float softmax_scale = 1.0f / std::sqrt(head_size); + const float softmax_unscale = std::sqrt(head_size); + + FlashAttnBwdParamsV2 params = + FlashAttnBwdParamsV2(ctx, + batch_size, + seqlen_q, + seqlen_k, + num_heads, + num_heads_k, + head_size, + dropout, + softmax_scale, + causal, + q.dtype(), + attn_mask, + seed_offset.data()); + + // Error VLOG(10) << "[FlashAttn Forward] q.shape=[" << q.shape() << "], k.shape=[" + // Error << k.shape() << "], v.shape=[" << v.shape() << "]"; + // Error VLOG(10) << "[FlashAttn Forward] dropout=" << dropout + // Error << ", seed=" << params.seed << ", offset=" << params.offset; + // Error VLOG(10) << "[FlashAttn Forward] softmax_scale=" << softmax_scale + // Error << ", softmax_unscale=" << softmax_unscale; + // Error if (attn_mask.get_ptr()) { + // Error VLOG(10) << "[FlashAttn Backward] attn_mask.shape=[" + // Error << (attn_mask.get_ptr())->shape() << "]"; + // Error } +#ifdef PADDLE_WITH_ADVANCED + int num_splits = 1; // Error get_num_split(); +#else + int num_splits = 0; // Error get_num_split(); +#endif + + bool succ = flash_attn_bwd( + dout.data(), + q.data(), + k.data(), + v.data(), + out.data(), + params.softmax_d.data(), + softmax_lse.data(), + params.rng_state.data(), + dq_ptr, + dk_ptr, + dv_ptr, + params.dq_accum.data(), + params.batch_size, + params.max_seqlen_q, + params.max_seqlen_k, + params.seqlen_q_rounded, + params.seqlen_k_rounded, + params.num_heads, + params.num_heads_k, + params.head_size, + params.head_size_rounded, + params.dropout, + params.softmax_scale, + softmax_unscale, + params.causal, + params.is_bf16, + num_splits, + stream, + params.seed, + params.offset, + params.attn_mask_tensor ? params.attn_mask_tensor->data() : nullptr, + params.attn_mask_tensor ? params.mask_dims.data() : nullptr); + CheckFlashAttnStatus(succ); +} + +std::vector FaBwd( + const paddle::Tensor& q, + const paddle::Tensor& k, + const paddle::Tensor& v, + const paddle::Tensor& out, + const paddle::Tensor& softmax_lse, + const paddle::Tensor& seed_offset, + const paddle::optional& attn_mask, + const paddle::Tensor& dout, + float dropout, + bool causal) { + + paddle::Tensor dq = paddle::empty(q.shape(), q.type()); + paddle::Tensor dk = paddle::empty(q.shape(), q.type()); + paddle::Tensor dv = paddle::empty(q.shape(), q.type()); + const phi::GPUContext *ctx{nullptr}; + //Error auto ctx = phi::GPUContext(); + switch(q.type()){ + case paddle::DataType::FLOAT16: + FlashAttnGradKernel(*ctx,q,k,v,out, softmax_lse, seed_offset, attn_mask, dout, dropout, causal,&dq, &dk, &dv); + break; + case paddle::DataType::BFLOAT16: + FlashAttnGradKernel(*ctx,q,k,v,out, softmax_lse, seed_offset, attn_mask, dout, dropout, causal, &dq, &dk, &dv); + break; + default: + break; + // Error + } + return {dq, dk, dv}; +} + +std::vector> FaBwdInferShape( + std::vector q_shape, + std::vector k_shape, + std::vector v_shape, + std::vector out_shape, + std::vector softmax_lse, + std::vector seed_offset, + std::vector mask_shape, + std::vector dout, + float dropout, + bool causal) { + return {q_shape, k_shape, v_shape, out_shape, softmax_lse,seed_offset, mask_shape,dout}; +} + +PD_BUILD_OP(flash_attn_with_mask_grad) + .Inputs({"q", "k", "v", "out", "softmax_lse","seed_offset", "attn_mask","dout"}) + .Outputs({"dq", "dk", "dv"}) + .Attrs({"dropout: float","causal:bool"}) + .SetKernelFn(PD_KERNEL(FaBwd)) + .SetInferShapeFn(PD_INFER_SHAPE(FaBwdInferShape)); diff --git a/csrc/flash_attn/src/flash_bwd_launch_template.h b/csrc/flash_attn/src/flash_bwd_launch_template.h index 2c62e6c5797..d611b5deaee 100644 --- a/csrc/flash_attn/src/flash_bwd_launch_template.h +++ b/csrc/flash_attn/src/flash_bwd_launch_template.h @@ -64,6 +64,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, const bool is_attn_mask = params.attn_mask_ptr != nullptr; const bool is_deterministic = params.num_splits == 1; // printf("smem_size_dq_dk_dv = %d\n", smem_size_dq_dk_dv); +#ifdef PADDLE_WITH_ADVANCED BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { @@ -82,6 +83,21 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms, cudaStream_t stream, }); }); }); +#else + BOOL_SWITCH(params.is_causal, IsCausalConst, [&] { + BOOL_SWITCH(is_even_MN, IsEvenMNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + auto kernel = &flash_bwd_dq_dk_dv_loop_seqk_parallel_kernel; + if (smem_size_dq_dk_dv >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size_dq_dk_dv)); + } + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); +#endif auto kernel_dq = &flash_bwd_convert_dq_kernel; if (Kernel_traits::kSmemdQSize >= 48 * 1024) { diff --git a/csrc/flash_attn/src/flash_fwd_launch_template.h b/csrc/flash_attn/src/flash_fwd_launch_template.h index 6c638261766..ae707d0e9d4 100644 --- a/csrc/flash_attn/src/flash_fwd_launch_template.h +++ b/csrc/flash_attn/src/flash_fwd_launch_template.h @@ -36,6 +36,7 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { const bool return_softmax = params.p_ptr != nullptr; const bool is_attn_mask = params.attn_mask_ptr != nullptr; const bool is_equal_qk = (params.seqlen_q == params.seqlen_k) && (Is_causal) && (!is_attn_mask); +#ifdef PADDLE_WITH_ADVANCED BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { @@ -59,6 +60,29 @@ void run_flash_fwd(Flash_fwd_params ¶ms, cudaStream_t stream) { }); }); }); +#else + BOOL_SWITCH(is_even_N, IsEvenNConst, [&] { + BOOL_SWITCH(is_even_K, IsEvenKConst, [&] { + BOOL_SWITCH(return_softmax, ReturnSoftmaxConst, [&] { + BOOL_SWITCH(is_equal_qk, Is_equal_seq_qk, [&] { + // Will only return softmax if dropout, to reduce compilation time. + auto kernel = &flash_fwd_kernel; + // auto kernel = &flash_fwd_kernel; + if (smem_size >= 48 * 1024) { + C10_CUDA_CHECK(cudaFuncSetAttribute( + kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + } + int ctas_per_sm; + cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::kNThreads, smem_size); + // printf("smem_size = %d, CTAs per SM = %d\n", int(smem_size), ctas_per_sm); + kernel<<>>(params); + C10_CUDA_KERNEL_LAUNCH_CHECK(); + }); + }); + }); + }); +#endif } template diff --git a/csrc/setup.py b/csrc/setup.py new file mode 100644 index 00000000000..4c22485f83c --- /dev/null +++ b/csrc/setup.py @@ -0,0 +1,107 @@ +#licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import multiprocessing +from pathlib import Path +import os + +def get_gencode_flags(): + import paddle + + prop = paddle.device.cuda.get_device_properties() + cc = prop.major * 10 + prop.minor + return ["-gencode", "arch=compute_{0},code=sm_{0}".format(cc)] + + +def run(func): + p = multiprocessing.Process(target=func) + p.start() + p.join() + + +def change_pwd(): + path = os.path.dirname(__file__) + if path: + os.chdir(path) + + +this_dir = os.path.dirname(os.path.abspath(__file__)) +def setup_fused_ln(): + from paddle.utils.cpp_extension import CUDAExtension, setup + + gencode_flags = get_gencode_flags() + change_pwd() + setup( + name="flash_attn", + ext_modules=CUDAExtension( + sources=[ + "flash_attn.cu", + "flash_attn/src/cuda_utils.cu", + "flash_attn/src/flash_fwd_hdim32_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim32_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim64_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim64_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim96_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim96_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim128_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim128_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim160_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim160_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim192_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim192_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim224_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim224_bf16_sm80.cu", + "flash_attn/src/flash_fwd_hdim256_fp16_sm80.cu", + "flash_attn/src/flash_fwd_hdim256_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim32_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim32_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim64_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim64_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim96_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim96_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim128_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim128_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim160_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim160_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim192_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim192_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim224_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim224_bf16_sm80.cu", + "flash_attn/src/flash_bwd_hdim256_fp16_sm80.cu", + "flash_attn/src/flash_bwd_hdim256_bf16_sm80.cu", + ], + extra_compile_args={ + "cxx": ["-O3", "-std=c++17"] + gencode_flags, + "nvcc": [ + "-O3", + "-std=c++17", + "-U__CUDA_NO_HALF_OPERATORS__", + "-U__CUDA_NO_HALF_CONVERSIONS__", + "-U__CUDA_NO_HALF2_OPERATORS__", + "-U__CUDA_NO_BFLOAT16_CONVERSIONS__", + "--expt-relaxed-constexpr", + "--expt-extended-lambda", + "--use_fast_math", + "--ptxas-options=-v", + # "--ptxas-options=-O2", + "-lineinfo" + ] + }, + include_dirs=[ + "flash_attn", + "flash_attn/src", + "cutlass/include", + ], + ), + ) + + +run(setup_fused_ln)