From 784210dee3b22b3fe492be4e2e9840dd112c3540 Mon Sep 17 00:00:00 2001 From: sunnycase Date: Thu, 25 Sep 2025 02:27:24 +0000 Subject: [PATCH 01/13] Add cuda test --- CMakeLists.txt | 2 +- cmake/compile_flags.cmake | 7 +- ntt/include/nncase/ntt/dimension.h | 23 +- ntt/include/nncase/ntt/kernels/reduce.h | 12 +- ntt/include/nncase/ntt/kernels/where.h | 2 +- ntt/include/nncase/ntt/padding.h | 6 +- ntt/include/nncase/ntt/primitive_ops.h | 2 +- ntt/include/nncase/ntt/shape.h | 14 +- ntt/include/nncase/ntt/ukernels/u_transpose.h | 2 +- ntt/test/ctest/test_ntt_cuda.cpp | 0 src/Native/include/nncase/runtime/model.h | 4 +- .../include/nncase/runtime/simple_types.h | 2 +- src/Native/src/CMakeLists.txt | 2 + src/Native/src/cuda_test.cu | 1438 +++++++++++++++++ 14 files changed, 1480 insertions(+), 36 deletions(-) create mode 100644 ntt/test/ctest/test_ntt_cuda.cpp create mode 100644 src/Native/src/cuda_test.cu diff --git a/CMakeLists.txt b/CMakeLists.txt index 0193cbf70..833ce4647 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,7 +37,7 @@ endif() project(nncase VERSION ${NNCASE_VERSION} - LANGUAGES C CXX ASM) + LANGUAGES C CXX ASM CUDA) option(DOTNET_INIT_FOR_CONFIG "Initialize dotnet from runtimeconfig" OFF) option(BUILD_PYTHON_BINDING "Build python binding" ON) diff --git a/cmake/compile_flags.cmake b/cmake/compile_flags.cmake index b5d7a36c4..c99e28c14 100644 --- a/cmake/compile_flags.cmake +++ b/cmake/compile_flags.cmake @@ -4,7 +4,7 @@ if (MSVC) set(PYBIND11_CPP_STANDARD "/std:c++latest") else() add_compile_options(-fvisibility=hidden) - add_compile_options(-Wall -Wextra -pedantic -Werror -Wno-multichar -Wno-missing-field-initializers -Wno-unused-function -Wno-type-limits -Wno-unused-local-typedefs -Wno-sign-compare) + add_compile_options(-Wall -Wextra -Wno-missing-field-initializers -Wno-unused-function -Wno-type-limits -Wno-unused-local-typedefs -Wno-sign-compare) if (APPLE) add_compile_options(-Wno-four-char-constants -Wno-sometimes-uninitialized -Wno-deprecated -Wno-braced-scalar-init) elseif (CMAKE_CXX_COMPILER_ID STREQUAL "Clang") @@ -15,6 +15,11 @@ else() endif() endif() +if (CMAKE_CUDA_COMPILER) + message(STATUS "Configuring for CUDA") + add_compile_options($<$:--expt-relaxed-constexpr>) +endif() + if(${CMAKE_SYSTEM_PROCESSOR} MATCHES "(x86)|(X86)|(amd64)|(AMD64)|(x86_64)|(X86_64)") if (MSVC) diff --git a/ntt/include/nncase/ntt/dimension.h b/ntt/include/nncase/ntt/dimension.h index d77e7f5c8..aa9973d09 100644 --- a/ntt/include/nncase/ntt/dimension.h +++ b/ntt/include/nncase/ntt/dimension.h @@ -16,6 +16,7 @@ #include "primitive_ops.h" #include "tensor_traits.h" #include +#include #include #include #include @@ -60,7 +61,7 @@ template struct char_literal { }; } // namespace detail -template inline constexpr auto operator"" _dim() { +template inline constexpr auto operator""_dim() { constexpr auto value = detail::char_literal::to_int; return fixed_dim_v; } @@ -184,16 +185,15 @@ constexpr auto positive_index(const TIndex &index, } else { return index; } + } else if constexpr (std::unsigned_integral) { + return index; } else { return index < 0 ? index + dim : index; } } -namespace detail { -template struct dim_where_impl; - -template -struct dim_where_impl { +namespace ops { +template struct where { constexpr dim_t operator()(const Cond &cond, const T &true_dim, const F &false_dim) const noexcept { return cond ? dim_value(true_dim) : dim_value(false_dim); @@ -201,7 +201,7 @@ struct dim_where_impl { }; template -struct dim_where_impl, T, F> { +struct where, T, F> { constexpr auto operator()(const std::integral_constant &, [[maybe_unused]] const T &true_dim, @@ -213,12 +213,5 @@ struct dim_where_impl, T, F> { } } }; -} // namespace detail - -template -constexpr auto where(const Cond &cond, const T &true_dim, - const F &false_dim) noexcept { - detail::dim_where_impl impl; - return impl(cond, true_dim, false_dim); -} +} // namespace ops } // namespace nncase::ntt diff --git a/ntt/include/nncase/ntt/kernels/reduce.h b/ntt/include/nncase/ntt/kernels/reduce.h index 25d4d786d..2c25e0eea 100644 --- a/ntt/include/nncase/ntt/kernels/reduce.h +++ b/ntt/include/nncase/ntt/kernels/reduce.h @@ -14,11 +14,10 @@ */ #pragma once #include "../primitive_ops.h" +#include "../shape.h" #include "../shape_infer/reduce.h" #include "../ukernels.h" #include "../utility.h" -#include "nncase/ntt/dimension.h" -#include "nncase/ntt/tensor_traits.h" #include namespace nncase::ntt { @@ -152,11 +151,10 @@ class reduce_impl { } } - template - constexpr void - apply_contiguous_reduce(dynamic_shape_t &index, - size_t conti_dims, const TSubIn &input, - TInElem &reduced_in) { + template + constexpr void apply_contiguous_reduce(TIndex &index, size_t conti_dims, + const TSubIn &input, + TInElem &reduced_in) { const auto outer_dims = TSubIn::rank() - conti_dims; const auto axis_v = fixed_dim_v; if (Axis >= outer_dims) { diff --git a/ntt/include/nncase/ntt/kernels/where.h b/ntt/include/nncase/ntt/kernels/where.h index 8e4ccadd8..e2604528a 100644 --- a/ntt/include/nncase/ntt/kernels/where.h +++ b/ntt/include/nncase/ntt/kernels/where.h @@ -51,7 +51,7 @@ class where_impl : public elementwise_impl, } // namespace detail template -void where(const TCond &cond, const TX &x, const TY &y, TOut &&output) { +constexpr void where(const TCond &cond, const TX &x, const TY &y, TOut &&output) { detail::where_impl>()(cond, x, y, output); } } // namespace nncase::ntt diff --git a/ntt/include/nncase/ntt/padding.h b/ntt/include/nncase/ntt/padding.h index 2e5fab05d..25d8c7229 100644 --- a/ntt/include/nncase/ntt/padding.h +++ b/ntt/include/nncase/ntt/padding.h @@ -184,9 +184,11 @@ template struct dynamic_paddings_type_impl; template struct dynamic_paddings_type_impl> { - template using elem_type = dim_t; + template struct elem_type { + using type = padding_t; + }; - using type = paddings_t...>; + using type = paddings_t::type...>; }; template diff --git a/ntt/include/nncase/ntt/primitive_ops.h b/ntt/include/nncase/ntt/primitive_ops.h index 430e9406f..252daaed2 100644 --- a/ntt/include/nncase/ntt/primitive_ops.h +++ b/ntt/include/nncase/ntt/primitive_ops.h @@ -486,7 +486,7 @@ constexpr T1 clamp(const T1 &v, const T2 &min, const T2 &max) noexcept { return ops::clamp()(v, min, max); } -template +template constexpr auto where(const TCond &cond, const TX &x, const TY &y) noexcept { return ops::where()(cond, x, y); } diff --git a/ntt/include/nncase/ntt/shape.h b/ntt/include/nncase/ntt/shape.h index a667f36c5..af27529f6 100644 --- a/ntt/include/nncase/ntt/shape.h +++ b/ntt/include/nncase/ntt/shape.h @@ -41,9 +41,11 @@ struct dynamic_dims_type_impl; template