diff --git a/excuter/op-mem-cuda/src/deepx/dtype_cuda.hpp b/excuter/op-mem-cuda/src/deepx/dtype_cuda.hpp index 57ef5eb..98f212c 100644 --- a/excuter/op-mem-cuda/src/deepx/dtype_cuda.hpp +++ b/excuter/op-mem-cuda/src/deepx/dtype_cuda.hpp @@ -3,6 +3,7 @@ #include #include +#include #include "deepx/dtype.hpp" @@ -34,6 +35,27 @@ namespace deepx else return Precision::Any; } + + + template <> + struct to_tensor_type> { + using type = nv_bfloat16; + }; + + template <> + struct to_tensor_type> { + using type = half; + }; + + template <> + struct to_tensor_type> { + using type = __nv_fp8_e5m2; + }; + + template <> + struct to_tensor_type> { + using type = __nv_fp8_e4m3; + }; } #endif // DEEPX_DTYPE_CUDA_HPP