40 changes: 40 additions & 0 deletions executor/op-mem-cuda/src/deepx/dtype_cuda.hpp
@@ -4,6 +4,7 @@
#include <cuda_fp16.h>
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <type_traits>

#include "deepx/dtype.hpp"

@@ -56,6 +57,45 @@ namespace deepx
struct to_tensor_type<PrecisionWrapper<Precision::Float8E4M3>> {
using type = __nv_fp8_e4m3;
};



template <typename T>
struct fp8_format_map;

template <>
struct fp8_format_map<__nv_fp8_e5m2> {
static constexpr __nv_fp8_interpretation_t value = __NV_E5M2;
};

template <>
struct fp8_format_map<__nv_fp8_e4m3> {
static constexpr __nv_fp8_interpretation_t value = __NV_E4M3;
};

template<typename T>
struct is_fp8 : std::false_type {}; // defaults to false

template<> struct is_fp8<__nv_fp8_e4m3> : std::true_type {};
template<> struct is_fp8<__nv_fp8_e5m2> : std::true_type {};


template <typename T>
inline constexpr bool is_fp8_v = is_fp8<T>::value;

template <typename T>
struct to_half {
static __host__ __device__ __half convert(T a) {
return __nv_cvt_fp8_to_halfraw(static_cast<__nv_fp8_storage_t>(a), fp8_format_map<T>::value);
}
};

template <typename T>
struct to_fp8 {
static __host__ __device__ T convert(half a) {
return static_cast<T>(__nv_cvt_halfraw_to_fp8(a, __NV_SATFINITE, fp8_format_map<T>::value));
}
};
}

#endif // DEEPX_DTYPE_CUDA_HPP
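
A minimal usage sketch of the traits added above (not part of this diff): an fp8 value is widened to __half with to_half, computed on in fp16, then narrowed back with to_fp8, while is_fp8_v gates the overload. The kernel name scale_fp8_kernel and the scaling operation are illustrative assumptions.

// Hypothetical illustration only; assumes the headers already included by
// dtype_cuda.hpp (<cuda_fp16.h>, <cuda_fp8.h>, <type_traits>).
template <typename T, std::enable_if_t<deepx::is_fp8_v<T>, int> = 0>
__global__ void scale_fp8_kernel(const T *in, T *out, const __half factor, const int size)
{
    int idx = blockIdx.x * blockDim.x + threadIdx.x;
    if (idx < size)
    {
        __half widened = deepx::to_half<T>::convert(in[idx]); // fp8 -> fp16
        __half scaled = __hmul(widened, factor);              // compute in fp16
        out[idx] = deepx::to_fp8<T>::convert(scaled);         // fp16 -> fp8
    }
}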
23 changes: 6 additions & 17 deletions executor/op-mem-cuda/src/deepx/tensorfunc/cuda_math.cuh
@@ -6,6 +6,7 @@
#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cublas_v2.h>
#include "deepx/dtype_cuda.hpp"

namespace deepx::tensorfunc
{
@@ -38,24 +39,12 @@ namespace deepx::tensorfunc
*out = hsqrt(*a);
}

template <>
__device__ __forceinline__ void deepx_sqrt<__nv_fp8_e4m3>(const __nv_fp8_e4m3 *a, __nv_fp8_e4m3 *out)
template <typename T, std::enable_if_t<is_fp8_v<T>, int> = 0>
__device__ __forceinline__ void deepx_sqrt(const T *a, T *out)
{
__half input_fp16 = __nv_cvt_fp8_to_halfraw(static_cast<__nv_fp8_storage_t>(*a), __NV_E4M3);
__half result_fp16 = hsqrt(input_fp16); // CUDA built-in half-precision square root
*out = static_cast<__nv_fp8_e4m3>(__nv_cvt_halfraw_to_fp8(result_fp16, __NV_SATFINITE, __NV_E4M3));
}

template <>
__device__ __forceinline__ void deepx_sqrt<__nv_fp8_e5m2>(const __nv_fp8_e5m2 *a, __nv_fp8_e5m2 *out)
{
__half input_fp16 = __nv_cvt_fp8_to_halfraw(static_cast<__nv_fp8_storage_t>(*a), __NV_E5M2);

// 2. perform the square root
__half result_fp16 = hsqrt(input_fp16);

// 3. convert back to FP8 (E5M2 format)
*out =static_cast<__nv_fp8_e5m2>(__nv_cvt_halfraw_to_fp8(result_fp16, __NV_SATFINITE, __NV_E5M2));
__half input_half = to_half<T>::convert(*a);
__half result_half = hsqrt(input_half); // CUDA built-in half-precision square root
*out = to_fp8<T>::convert(result_half);
}
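// Illustrative sketch, not part of this diff: because the two fp8
// specializations were folded into one enable_if-constrained overload,
// a call site can use deepx_sqrt uniformly and the fp8 path is picked
// automatically for __nv_fp8_e4m3 and __nv_fp8_e5m2. The kernel below is a
// hypothetical example of such a call site.
template <typename T>
__global__ void sqrt_elementwise_example_kernel(const T *A, T *C, const int size)
{
    int stride = blockDim.x * gridDim.x;
    for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride)
    {
        deepx_sqrt(&A[idx], &C[idx]); // resolves to the fp8 overload when T is an fp8 type
    }
}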


@@ -1,12 +1,14 @@
#ifndef DEEPX_TENSORFUNC_ELEMENTWISE_MIAO_BYTE_COMPARE_CU
#define DEEPX_TENSORFUNC_ELEMENTWISE_MIAO_BYTE_COMPARE_CU

#include <cuda_fp8.h>
#include "deepx/tensorfunc/cuda.hpp"
#include "deepx/tensorfunc/authors.hpp"
#include "deepx/tensorfunc/vector_cuda.cuh"
#include "deepx/dtype_cuda.hpp"
namespace deepx::tensorfunc
{
template <typename T>
template <typename T, std::enable_if_t<!is_fp8_v<T>, int> = 0>
__global__ void max_kernel(const T *A, const T *B, T *C, const int size)
{
int stride = blockDim.x * gridDim.x;
@@ -16,6 +18,20 @@ namespace deepx::tensorfunc
}
}

// fp8 overload: operands are widened to __half for the comparison, then the
// result is narrowed back to fp8.
template <typename T, std::enable_if_t<is_fp8_v<T>, int> = 0>
__global__ void max_kernel(const T *A, const T *B, T *C, const int size)
{
int stride = blockDim.x * gridDim.x;
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride)
{
__half temp_a = to_half<T>::convert(A[idx]);
__half temp_b = to_half<T>::convert(B[idx]);
__half temp_c = temp_a > temp_b ? temp_a : temp_b;
C[idx] = to_fp8<T>::convert(temp_c);
}
}


template <typename T>
void launch_max(const T *A, const T *B, T *C, const int size)
{
@@ -32,8 +48,10 @@ namespace deepx::tensorfunc
template void launch_max<int32_t>(const int32_t *A, const int32_t *B, int32_t *C, const int size);
template void launch_max<int16_t>(const int16_t *A, const int16_t *B, int16_t *C, const int size);
template void launch_max<int8_t>(const int8_t *A, const int8_t *B, int8_t *C, const int size);
template void launch_max<__nv_fp8_e4m3>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 *B, __nv_fp8_e4m3 *C, const int size);
template void launch_max<__nv_fp8_e5m2>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 *B, __nv_fp8_e5m2 *C, const int size);

template <typename T>
template <typename T, std::enable_if_t<!is_fp8_v<T>, int> = 0>
__global__ void maxscalar_kernel(const T *A, const T scalar, T *C, const int size)
{
int stride = blockDim.x * gridDim.x;
@@ -43,6 +61,19 @@ namespace deepx::tensorfunc
}
}

template <typename T, std::enable_if_t<is_fp8_v<T>, int> = 0>
__global__ void maxscalar_kernel(const T *A, const T scalar, T *C, const int size)
{
int stride = blockDim.x * gridDim.x;
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride)
{
__half temp_a = to_half<T>::convert(A[idx]);
__half temp_scalar = to_half<T>::convert(scalar);
__half temp_c = temp_a > temp_scalar ? temp_a : temp_scalar;
C[idx] = to_fp8<T>::convert(temp_c);
}
}

template <typename T>
void launch_maxscalar(const T *A, const T scalar, T *C, const int size)
{
@@ -59,8 +90,10 @@ namespace deepx::tensorfunc
template void launch_maxscalar<int32_t>(const int32_t *A, const int32_t scalar, int32_t *C, const int size);
template void launch_maxscalar<int16_t>(const int16_t *A, const int16_t scalar, int16_t *C, const int size);
template void launch_maxscalar<int8_t>(const int8_t *A, const int8_t scalar, int8_t *C, const int size);
template void launch_maxscalar<__nv_fp8_e4m3>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 scalar, __nv_fp8_e4m3 *C, const int size);
template void launch_maxscalar<__nv_fp8_e5m2>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 scalar, __nv_fp8_e5m2 *C, const int size);

template <typename T>
template <typename T, std::enable_if_t<!is_fp8_v<T>, int> = 0>
__global__ void min_kernel(const T *A, const T *B, T *C, const int size)
{
int stride = blockDim.x * gridDim.x;
@@ -70,6 +103,20 @@ namespace deepx::tensorfunc
}
}


template <typename T, std::enable_if_t<is_fp8_v<T>, int> = 0>
__global__ void min_kernel(const T *A, const T *B, T *C, const int size)
{
int stride = blockDim.x * gridDim.x;
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride)
{
__half temp_a = to_half<T>::convert(A[idx]);
__half temp_b = to_half<T>::convert(B[idx]);
__half temp_c = temp_a < temp_b ? temp_a : temp_b;
C[idx] = to_fp8<T>::convert(temp_c);
}
}

template <typename T>
void launch_min(const T *A, const T *B, T *C, const int size)
{
@@ -86,8 +133,10 @@ namespace deepx::tensorfunc
template void launch_min<int32_t>(const int32_t *A, const int32_t *B, int32_t *C, const int size);
template void launch_min<int16_t>(const int16_t *A, const int16_t *B, int16_t *C, const int size);
template void launch_min<int8_t>(const int8_t *A, const int8_t *B, int8_t *C, const int size);
template void launch_min<__nv_fp8_e4m3>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 *B, __nv_fp8_e4m3 *C, const int size);
template void launch_min<__nv_fp8_e5m2>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 *B, __nv_fp8_e5m2 *C, const int size);

template <typename T>
template <typename T, std::enable_if_t<!is_fp8_v<T>, int> = 0>
__global__ void minscalar_kernel(const T *A, const T scalar, T *C, const int size)
{
int stride = blockDim.x * gridDim.x;
@@ -97,6 +146,19 @@ namespace deepx::tensorfunc
}
}

template <typename T, std::enable_if_t<is_fp8_v<T>, int> = 0>
__global__ void minscalar_kernel(const T *A, const T scalar, T *C, const int size)
{
int stride = blockDim.x * gridDim.x;
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride)
{
__half temp_a = to_half<T>::convert(A[idx]);
__half temp_scalar = to_half<T>::convert(scalar);
__half temp_c = temp_a < temp_scalar ? temp_a : temp_scalar;
C[idx] = to_fp8<T>::convert(temp_c);
}
}

template <typename T>
void launch_minscalar(const T *A, const T scalar, T *C, const int size)
{
@@ -113,9 +175,11 @@ namespace deepx::tensorfunc
template void launch_minscalar<int32_t>(const int32_t *A, const int32_t scalar, int32_t *C, const int size);
template void launch_minscalar<int16_t>(const int16_t *A, const int16_t scalar, int16_t *C, const int size);
template void launch_minscalar<int8_t>(const int8_t *A, const int8_t scalar, int8_t *C, const int size);
template void launch_minscalar<__nv_fp8_e4m3>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 scalar, __nv_fp8_e4m3 *C, const int size);
template void launch_minscalar<__nv_fp8_e5m2>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 scalar, __nv_fp8_e5m2 *C, const int size);

// equal
template <typename T,typename MaskT>
template <typename T,typename MaskT, std::enable_if_t<!is_fp8_v<T>, int> = 0>
__global__ void equalwithepsilon_kernel(const T *A, const T *B, const float epsilon, MaskT *mask, const int size)
{
int stride = blockDim.x * gridDim.x;
@@ -133,7 +197,28 @@ namespace deepx::tensorfunc
}
}

template <typename T,typename MaskT>
// equal (fp8 overload): compare after widening to __half
template <typename T, typename MaskT, std::enable_if_t<is_fp8_v<T>, int> = 0>
__global__ void equalwithepsilon_kernel(const T *A, const T *B, const float epsilon, MaskT *mask, const int size)
{
int stride = blockDim.x * gridDim.x;
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride)
{
float diff = fabsf(static_cast<float>(to_half<T>::convert(A[idx])) - static_cast<float>(to_half<T>::convert(B[idx])));
if (diff < epsilon)
{
mask[idx] = 1;
}
else
{
mask[idx] = 0;
}
}
}



template <typename T,typename MaskT, std::enable_if_t<!is_fp8_v<T>, int> = 0>
__global__ void equal_kernel(const T *A, const T *B, MaskT *mask, const int size)
{
int stride = blockDim.x * gridDim.x;
@@ -143,6 +228,16 @@ namespace deepx::tensorfunc
}
}

template <typename T,typename MaskT, std::enable_if_t<is_fp8_v<T>, int> = 0>
__global__ void equal_kernel(const T *A, const T *B, MaskT *mask, const int size)
{
int stride = blockDim.x * gridDim.x;
for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < size; idx += stride)
{
mask[idx] = (to_half<T>::convert(A[idx]) == to_half<T>::convert(B[idx]));
}
}

template <typename T,typename MaskT>
void launch_equal(const T *A, const T *B, const float epsilon, MaskT *mask, const int size)
{
@@ -166,6 +261,8 @@ namespace deepx::tensorfunc
template void launch_equal<int32_t,bool>(const int32_t *A, const int32_t *B, const float epsilon, bool *mask, const int size);
template void launch_equal<int16_t,bool>(const int16_t *A, const int16_t *B, const float epsilon, bool *mask, const int size);
template void launch_equal<int8_t,bool>(const int8_t *A, const int8_t *B, const float epsilon, bool *mask, const int size);
template void launch_equal<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 *B, const float epsilon, bool *mask, const int size);
template void launch_equal<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 *B, const float epsilon, bool *mask, const int size);

// equalscalar
template <typename T,typename MaskT>
@@ -219,6 +316,8 @@ namespace deepx::tensorfunc
template void launch_equalscalar<int32_t,bool>(const int32_t *A, const int32_t scalar, const float epsilon, bool *mask, const int size);
template void launch_equalscalar<int16_t,bool>(const int16_t *A, const int16_t scalar, const float epsilon, bool *mask, const int size);
template void launch_equalscalar<int8_t,bool>(const int8_t *A, const int8_t scalar, const float epsilon, bool *mask, const int size);
// template void launch_equalscalar<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 scalar, const float epsilon, bool *mask, const int size);
// template void launch_equalscalar<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 scalar, const float epsilon, bool *mask, const int size);

// not equal
template <typename T,typename MaskT>
@@ -272,6 +371,8 @@ namespace deepx::tensorfunc
template void launch_notequal<int32_t,bool>(const int32_t *A, const int32_t *B, const float epsilon, bool *mask, const int size);
template void launch_notequal<int16_t,bool>(const int16_t *A, const int16_t *B, const float epsilon, bool *mask, const int size);
template void launch_notequal<int8_t,bool>(const int8_t *A, const int8_t *B, const float epsilon, bool *mask, const int size);
// template void launch_notequal<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 *B, const float epsilon, bool *mask, const int size);
// template void launch_notequal<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 *B, const float epsilon, bool *mask, const int size);

// notequalscalar
template <typename T,typename MaskT>
@@ -325,6 +426,8 @@ namespace deepx::tensorfunc
template void launch_notequalscalar<int32_t,bool>(const int32_t *A, const int32_t scalar, const float epsilon, bool *mask, const int size);
template void launch_notequalscalar<int16_t,bool>(const int16_t *A, const int16_t scalar, const float epsilon, bool *mask, const int size);
template void launch_notequalscalar<int8_t,bool>(const int8_t *A, const int8_t scalar, const float epsilon, bool *mask, const int size);
// template void launch_notequalscalar<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 scalar, const float epsilon, bool *mask, const int size);
// template void launch_notequalscalar<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 scalar, const float epsilon, bool *mask, const int size);

// less
template <typename T,typename MaskT>
@@ -353,6 +456,8 @@ namespace deepx::tensorfunc
template void launch_less<int32_t,bool>(const int32_t *A, const int32_t *B, bool *mask, const int size);
template void launch_less<int16_t,bool>(const int16_t *A, const int16_t *B, bool *mask, const int size);
template void launch_less<int8_t,bool>(const int8_t *A, const int8_t *B, bool *mask, const int size);
// template void launch_less<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 *B, bool *mask, const int size);
// template void launch_less<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 *B, bool *mask, const int size);

// lessscalar

@@ -382,7 +487,9 @@ namespace deepx::tensorfunc
template void launch_lessscalar<int32_t,bool>(const int32_t *A, const int32_t scalar, bool *mask, const int size);
template void launch_lessscalar<int16_t,bool>(const int16_t *A, const int16_t scalar, bool *mask, const int size);
template void launch_lessscalar<int8_t,bool>(const int8_t *A, const int8_t scalar, bool *mask, const int size);

// template void launch_lessscalar<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 scalar, bool *mask, const int size);
// template void launch_lessscalar<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 scalar, bool *mask, const int size);

// greater
template <typename T,typename MaskT>
__global__ void greater_kernel(const T *A, const T *B, MaskT *mask, const int size)
@@ -410,6 +517,8 @@ namespace deepx::tensorfunc
template void launch_greater<int32_t,bool>(const int32_t *A, const int32_t *B, bool *mask, const int size);
template void launch_greater<int16_t,bool>(const int16_t *A, const int16_t *B, bool *mask, const int size);
template void launch_greater<int8_t,bool>(const int8_t *A, const int8_t *B, bool *mask, const int size);
// template void launch_greater<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 *B, bool *mask, const int size);
// template void launch_greater<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 *B, bool *mask, const int size);

// greaterscalar
template <typename T,typename MaskT>
@@ -438,6 +547,8 @@ namespace deepx::tensorfunc
template void launch_greaterscalar<int32_t,bool>(const int32_t *A, const int32_t scalar, bool *mask, const int size);
template void launch_greaterscalar<int16_t,bool>(const int16_t *A, const int16_t scalar, bool *mask, const int size);
template void launch_greaterscalar<int8_t,bool>(const int8_t *A, const int8_t scalar, bool *mask, const int size);
// template void launch_greaterscalar<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 *A, const __nv_fp8_e4m3 scalar, bool *mask, const int size);
// template void launch_greaterscalar<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 *A, const __nv_fp8_e5m2 scalar, bool *mask, const int size);

// switch
template <typename T,typename casesT>
@@ -476,7 +587,9 @@ namespace deepx::tensorfunc
template void launch_switch<int16_t,int32_t>(const int16_t **tensorsdata, const int numTensors, const int32_t *cases, int16_t *C, const int size);
template void launch_switch<int8_t,int32_t>(const int8_t **tensorsdata, const int numTensors, const int32_t *cases, int8_t *C, const int size);
template void launch_switch<bool,int32_t>(const bool **tensorsdata, const int numTensors, const int32_t *cases, bool *C, const int size);

// template void launch_switch<__nv_fp8_e4m3,int32_t>(const __nv_fp8_e4m3 **tensorsdata, const int numTensors, const int32_t *cases, __nv_fp8_e4m3 *C, const int size);
// template void launch_switch<__nv_fp8_e5m2,int32_t>(const __nv_fp8_e5m2 **tensorsdata, const int numTensors, const int32_t *cases, __nv_fp8_e5m2 *C, const int size);

template void launch_switch<double,bool>(const double **tensorsdata, const int numTensors, const bool *cases, double *C, const int size);
template void launch_switch<float,bool>(const float **tensorsdata, const int numTensors, const bool *cases, float *C, const int size);
template void launch_switch<nv_bfloat16,bool>(const nv_bfloat16 **tensorsdata, const int numTensors, const bool *cases, nv_bfloat16 *C, const int size);
@@ -486,6 +599,7 @@ namespace deepx::tensorfunc
template void launch_switch<int16_t,bool>(const int16_t **tensorsdata, const int numTensors, const bool *cases, int16_t *C, const int size);
template void launch_switch<int8_t,bool>(const int8_t **tensorsdata, const int numTensors, const bool *cases, int8_t *C, const int size);
template void launch_switch<bool,bool>(const bool **tensorsdata, const int numTensors, const bool *cases, bool *C, const int size);

// template void launch_switch<__nv_fp8_e4m3,bool>(const __nv_fp8_e4m3 **tensorsdata, const int numTensors, const bool *cases, __nv_fp8_e4m3 *C, const int size);
// template void launch_switch<__nv_fp8_e5m2,bool>(const __nv_fp8_e5m2 **tensorsdata, const int numTensors, const bool *cases, __nv_fp8_e5m2 *C, const int size);
}
#endif // DEEPX_TENSORFUNC_ELEMENTWISE_MIAO_BYTE_COMPARE_CU
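
A hedged host-side sketch of exercising the newly instantiated fp8 path of launch_max (the wrapper function, buffer names, and the assumption that the pointers are device buffers are illustrative, not part of this change):

#include <cuda_runtime.h>

// Hypothetical example: dA, dB, dC are assumed to be device buffers holding
// `size` fp8 values, prepared by the caller; launch_max's declaration is
// assumed to be visible via the corresponding header.
void example_fp8_max(const __nv_fp8_e4m3 *dA, const __nv_fp8_e4m3 *dB,
                     __nv_fp8_e4m3 *dC, const int size)
{
    deepx::tensorfunc::launch_max<__nv_fp8_e4m3>(dA, dB, dC, size);
    cudaDeviceSynchronize(); // block until the kernel finishes so results can be read back
}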